#include <bsdf/diffuse.h>
#include <render/intersection.h>

// void *__augment_diffuse_eval(const BSDF *bsdf, BSDF *d_bsdf,
//                              const Intersection &its, Intersection &d_its,
//                              const Vector &wo, Vector &d_wo, EBSDFMode mode,
//                              EBSDFMode d_mode, Spectrum &ret, Spectrum
//                              &d_ret) {
//   return __enzyme_augmentfwd((void *)__diffuse_eval, enzyme_dup, bsdf,
//   d_bsdf,
//                              enzyme_dup, &its, &d_its, enzyme_dup, &wo,
//                              &d_wo, enzyme_const, mode, enzyme_dup, &ret,
//                              &d_ret);
// }

// void __gradient_diffuse_eval(const BSDF *bsdf, BSDF *d_bsdf,
//                              const Intersection &its, Intersection &d_its,
//                              const Vector &wo, Vector &d_wo, EBSDFMode mode,
//                              EBSDFMode d_mode, Spectrum &ret, Spectrum
//                              &d_ret, void *tape) {
//   __enzyme_reverse((void *)__diffuse_eval, enzyme_dup, bsdf, d_bsdf,
//   enzyme_dup,
//                    &its, &d_its, enzyme_dup, &wo, &d_wo, enzyme_const, mode,
//                    enzyme_dup, &ret, &d_ret, tape);
// }
// void *__enzyme_register_gradient_eval[] = {
//     (void *)__diffuse_eval,
//     (void *)__augment_diffuse_eval,
//     (void *)__gradient_diffuse_eval,
// };

// Default constructor: records this class's TYPE_ID in m_type.
// The reflectance member is not assigned here; it keeps whatever value its
// own default initialization provides.
DiffuseBSDF::DiffuseBSDF() {
    m_type = TYPE_ID;
}

// Constructs a diffuse BSDF whose reflectance member is initialized from the
// given spectrum, and records this class's TYPE_ID in m_type.
DiffuseBSDF::DiffuseBSDF(const Spectrum &reflectance)
    : reflectance(reflectance) {
    m_type = TYPE_ID;
}

// Evaluates the diffuse BSDF for the direction pair (its.wi, wo).
//
// Returns Spectrum::Zero() when the z component of either its.wi or wo is
// below Epsilon (presumably a direction at or below the surface in the local
// shading frame -- confirm against Intersection). Otherwise returns
//     reflectance.eval(its.uv) * INV_PI * wo.z()
// which is additionally multiplied by correction(its, wo) when mode is
// EBSDFMode::EImportanceWithCorrection.
BSDFEvalType DiffuseBSDF::eval(const Intersection &its, const Vector &wo,
                               EBSDFMode mode) const {
    // Reject configurations where either direction fails the z >= Epsilon test.
    const bool rejected = its.wi.z() < Epsilon || wo.z() < Epsilon;
    if (rejected)
        return Spectrum::Zero();

    // Every mode other than EImportanceWithCorrection uses the plain value.
    if (mode != EBSDFMode::EImportanceWithCorrection)
        return reflectance.eval(its.uv) * INV_PI * wo.z();

    return reflectance.eval(its.uv) * INV_PI * wo.z() * correction(its, wo);
}

// Samples an outgoing direction for the diffuse BSDF.
//
// Parameters:
//   its  - intersection record; its.wi.z() gates the sample and its.uv is the
//          lookup coordinate passed to reflectance.eval().
//   rnd  - random numbers; only rnd[0] and rnd[1] are consumed.
//   wo   - [out] direction returned by squareToCosineHemisphere(). Left
//          untouched when the sample fails.
//   pdf  - [out] squareToCosineHemispherePdf(wo); 0 when the sample fails.
//   eta  - [out] always set to 1.
//   mode - when EBSDFMode::EImportanceWithCorrection, the returned value is
//          multiplied by correction(its, wo).
//
// Returns Spectrum::Zero() if its.wi.z() < Epsilon; otherwise
// reflectance.eval(its.uv), optionally times correction(its, wo). Presumably
// this is the BSDF value already divided by the pdf -- confirm against callers.
//
// Note: unlike eval() and pdf(), no wo.z() < Epsilon test is applied to the
// freshly sampled direction.
BSDFEvalType DiffuseBSDF::sample(const Intersection &its, const Array3 &rnd,
                                 Vector &wo, Float &pdf, Float &eta,
                                 EBSDFMode mode) const {
    if (its.wi.z() < Epsilon) {
        // Failed sample: give the scalar out-parameters well-defined values so
        // a caller that reads them without first checking the returned
        // spectrum never observes stale or indeterminate data.
        pdf = 0.0f;
        eta = 1.0f;
        return Spectrum::Zero();
    }

    wo  = squareToCosineHemisphere(Vector2(rnd[0], rnd[1]));
    eta = 1.0f;
    pdf = squareToCosineHemispherePdf(wo);

    if (mode == EBSDFMode::EImportanceWithCorrection)
        return reflectance.eval(its.uv) * correction(its, wo);
    else
        return reflectance.eval(its.uv);
}

// Returns the density with which sample() would produce the direction wo.
//
// The result is 0 when the z component of either its.wi or wo is below
// Epsilon; otherwise it is squareToCosineHemispherePdf(wo).
Float DiffuseBSDF::pdf(const Intersection &its, const Vector &wo) const {
    const bool rejected = its.wi.z() < Epsilon || wo.z() < Epsilon;
    if (rejected)
        return 0.0;

    return squareToCosineHemispherePdf(wo);
}

// Expands the project macro PSDR_IMPL_BSDF_HELPER_FUNCTIONS for DiffuseBSDF.
// NOTE(review): the macro is defined outside this file; presumably it emits the
// shared BSDF helper definitions for this class -- confirm at its definition.
PSDR_IMPL_BSDF_HELPER_FUNCTIONS(DiffuseBSDF);
