#include "bsdf.h"
#include "intersection.h"
#include "intersectionAD.h"


SpectrumAD BSDF::evalAD(const IntersectionAD &its, const VectorAD &wo, EBSDFMode mode) const {
    // Base-class fallback for the differentiable evaluation. A concrete BSDF
    // presumably overrides this; reaching it trips the assertion in debug
    // builds, while with NDEBUG it yields a default-constructed SpectrumAD.
    (void)its;
    (void)wo;
    (void)mode;
    assert(false);
    return SpectrumAD();
}

SpectrumAD BSDF::sampleAD(const IntersectionAD &its, const Array3 &rnd, VectorAD &wo, Float &pdf, Float &eta, EBSDFMode mode) const {
    // Generic differentiable sampling:
    //  1. draw a direction with the non-differentiable sampler sample(),
    //  2. re-express it in the differentiable local frame,
    //  3. return evalAD(its, wo, mode) / pdf.
    // `wo` is written only on success; `pdf` and `eta` are whatever sample() wrote.
    // On failure (zero sample value or pdf <= Epsilon) a default-constructed
    // SpectrumAD is returned.
    SpectrumAD ret;
    Intersection its0 = its.toIntersection();
    Vector wo0;

    // Forward `mode` so the sampler runs in the same transport mode that
    // evalAD() is evaluated with below (previously it was silently dropped).
    const bool sampled = !sample(its0, rnd, wo0, pdf, eta, mode).isZero(Epsilon);
    if ( sampled && pdf > Epsilon ) {
        // The sampled direction is fixed in world space and then expressed in the
        // AD local frame of `its`.
        wo = its.toLocal(its0.toWorld(wo0));
        ret = evalAD(its, wo, mode)/pdf;
    }
    return ret;
}

Float BSDF::pdf(const IntersectionAD &its, const Vector3 &wo) const {
    // Convenience overload for AD intersections: strip the derivative data
    // and defer to the plain-Intersection pdf().
    const Intersection plainIts = its.toIntersection();
    return pdf(plainIts, wo);
}

Float BSDF::correction(const Intersection &its, const Vector &wo) const {
    // Shading-normal correction factor
    //     | (wi.z * (wo_world . n_g)) / (wo.z * (wi_world . n_g)) |
    // where wi/wo are local-frame directions (assumed to be the shading frame,
    // so .z() is the cosine to the shading normal) and n_g = its.geoFrame.n is
    // the geometric normal.
    const Float woDotGeoN = its.toWorld(wo).dot(its.geoFrame.n);
    const Float wiDotGeoN = its.toWorld(its.wi).dot(its.geoFrame.n);
    const Float denom = wo.z() * wiDotGeoN;

    // An exactly-zero denominator (grazing wo w.r.t. the shading normal, or wi
    // tangent to the geometric surface) would give inf or NaN, and NaN survives
    // multiplication by the zero cosine factor it is usually paired with.
    // Treat this degenerate configuration as contributing nothing.
    if (denom == 0)
        return 0;
    return std::abs((its.wi.z() * woDotGeoN) / denom);
}

FloatAD BSDF::correctionAD(const IntersectionAD &its, const VectorAD &wo) const {
    // Differentiable counterpart of correction():
    //     | (wi.z * (wo_world . n_g)) / (wo.z * (wi_world . n_g)) |
    // with derivatives propagated through the intersection frame, wi and wo.
    // Assuming the local frame's normal is the shading normal, this evaluates
    // to 1 when shading and geometric normals coincide. That is why no separate
    // "mesh has no shading normals" branch is required.
    VectorAD wi_global = its.toWorld(its.wi);
    VectorAD wo_global = its.toWorld(wo);
    FloatAD wiDotGeoN = wi_global.dot(its.geoFrame.n),
            woDotGeoN = wo_global.dot(its.geoFrame.n);
    FloatAD denom = wo.z() * wiDotGeoN;

    // Same degenerate-configuration guard as correction(): avoid inf/NaN values
    // (and gradients) when the denominator is exactly zero.
    if (denom.val == 0)
        return FloatAD(0.0);
    return ((its.wi.z()*woDotGeoN)/denom).abs();
}


std::string BSDF::toString() const {
    // The base class has no parameters of its own, so it reports a fixed tag
    // followed by a newline.
    std::ostringstream description;
    description << "Base BSDF []" << '\n';
    return description.str();
}
