#include "twosided.h"
#include "intersection.h"
#include "intersectionAD.h"
#include "utils.h"
#include <assert.h>

/**
 * Evaluates the nested BSDF on a two-sided surface.
 *
 * If the incident direction lies below the local surface (its.wi.z() < -Epsilon),
 * the whole query is mirrored through the tangent plane before being handed to
 * m_nested. Mirroring negates:
 *   - wi.z and wo.z;
 *   - the normal of the geometric frame;
 *   - the normal of the shading frame.
 * Any other query, including the near-grazing band wi.z() >= -Epsilon, is forwarded
 * unchanged. `mode` is passed through as-is.
 *
 * @param its   intersection record; wi is read in local coordinates via z()
 * @param wo    outgoing direction, in the same local coordinates as its.wi
 * @param mode  forwarded to the nested BSDF untouched
 * @return      the nested BSDF's value for the (possibly mirrored) query
 */
Spectrum TwosidedBSDF::eval(const Intersection &its, const Vector &wo, EBSDFMode mode) const {
    const bool backSide = its.wi.z() < -Epsilon;
    if ( !backSide )
        return m_nested->eval(its, wo, mode);

    // Outgoing direction reflected through the tangent plane.
    Vector woMirror(wo);
    woMirror.z() = -wo.z();

    // Intersection record reflected through the tangent plane: incident direction
    // plus the normals of both frames, so nested code sees a consistent front side.
    Intersection itsMirror(its);
    itsMirror.wi.z() = -itsMirror.wi.z();
    itsMirror.geoFrame.n = -itsMirror.geoFrame.n;
    itsMirror.shFrame.n = -itsMirror.shFrame.n;

    return m_nested->eval(itsMirror, woMirror, mode);
}


/**
 * Differentiable counterpart of eval(): evaluates the nested BSDF on a two-sided
 * surface using AD-typed inputs.
 *
 * A query whose incident direction lies below the local surface
 * (its.wi.z() < -Epsilon) is mirrored through the tangent plane before delegation.
 * Mirroring negates:
 *   - wi.z and wo.z;
 *   - the geometric-frame normal;
 *   - the shading-frame normal.
 * All other queries go straight to m_nested->evalAD. `mode` is forwarded as-is.
 *
 * @param its   AD intersection record; wi is read in local coordinates via z()
 * @param wo    AD outgoing direction, in the same local coordinates as its.wi
 * @param mode  forwarded to the nested BSDF untouched
 * @return      the nested BSDF's AD value for the (possibly mirrored) query
 */
SpectrumAD TwosidedBSDF::evalAD(const IntersectionAD &its, const VectorAD &wo, EBSDFMode mode) const {
    if ( its.wi.z() < -Epsilon ) {
        // Reflect the outgoing direction through the tangent plane.
        VectorAD woMirror(wo);
        woMirror.z() = -wo.z();

        // Reflect the intersection record: incident direction and both frame normals.
        IntersectionAD itsMirror(its);
        itsMirror.wi.z() = -itsMirror.wi.z();
        itsMirror.geoFrame.n = -itsMirror.geoFrame.n;
        itsMirror.shFrame.n = -itsMirror.shFrame.n;

        return m_nested->evalAD(itsMirror, woMirror, mode);
    }

    return m_nested->evalAD(its, wo, mode);
}


/**
 * Samples an outgoing direction from the nested BSDF on a two-sided surface.
 *
 * If the incident direction lies below the local surface (its.wi.z() < -Epsilon),
 * the nested BSDF is sampled on a mirrored copy of the intersection. In that copy,
 * wi.z and the normals of both the geometric and shading frames are negated.
 *
 * The sampled wo is then reflected back (its z negated) only when the returned
 * value is not zero according to isZero(Epsilon). A zero result therefore leaves
 * wo exactly as the nested sampler wrote it.
 *
 * @param its   intersection record; wi is read in local coordinates via z()
 * @param rnd   random numbers, forwarded to the nested sampler unchanged
 * @param wo    [out] sampled outgoing direction, in the same local coordinates as its.wi
 * @param pdf   [out] written only by the nested sampler; not adjusted here
 * @param eta   [out] written only by the nested sampler; not adjusted here
 * @param mode  forwarded to the nested BSDF untouched
 * @return      the value returned by the nested sampler
 */
Spectrum TwosidedBSDF::sample(const Intersection &its, const Array3 &rnd, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode) const {
    const bool backSide = its.wi.z() < -Epsilon;

    Intersection query(its);
    if ( backSide ) {
        // Present the nested BSDF with the mirrored configuration.
        query.wi.z() = -query.wi.z();
        query.geoFrame.n = -query.geoFrame.n;
        query.shFrame.n = -query.shFrame.n;
    }

    Spectrum value = m_nested->sample(query, rnd, wo, pdf, eta, mode);

    // Undo the mirroring on the sampled direction, but leave a rejected
    // (zero-valued) sample's wo untouched.
    const bool mirrorBack = backSide && !value.isZero(Epsilon);
    if ( mirrorBack )
        wo.z() = -wo.z();

    return value;
}


/**
 * Returns the nested BSDF's sampling density for `wo` on a two-sided surface.
 *
 * A query whose incident direction lies below the local surface
 * (its.wi.z() < -Epsilon) is mirrored through the tangent plane before delegation.
 * Mirroring negates:
 *   - wi.z and wo.z;
 *   - the geometric-frame normal;
 *   - the shading-frame normal.
 * All other queries are forwarded unchanged.
 *
 * @param its  intersection record; wi is read in local coordinates via z()
 * @param wo   outgoing direction, in the same local coordinates as its.wi
 * @return     the density reported by m_nested->pdf for the (possibly mirrored) query
 */
Float TwosidedBSDF::pdf(const Intersection &its, const Vector &wo) const {
    const bool backSide = its.wi.z() < -Epsilon;
    if ( !backSide )
        return m_nested->pdf(its, wo);

    // Reflect the outgoing direction through the tangent plane.
    Vector woMirror(wo);
    woMirror.z() = -wo.z();

    // Reflect the intersection record: incident direction and both frame normals.
    Intersection itsMirror(its);
    itsMirror.wi.z() = -itsMirror.wi.z();
    itsMirror.geoFrame.n = -itsMirror.geoFrame.n;
    itsMirror.shFrame.n = -itsMirror.shFrame.n;

    return m_nested->pdf(itsMirror, woMirror);
}
