#include <render/bsdf.h>
#include <render/intersection.h>
#include "bsdf/diffuse.h"
#include "bsdf/roughconductor.h"
#include "bsdf/roughdielectric.h"
#include "bsdf/null.h"
// NOTE(review): PSDR_INACTIVE_CLASS is defined elsewhere; presumably it marks the
// BSDF class as inactive (not differentiated) for automatic differentiation — confirm.
PSDR_INACTIVE_CLASS(BSDF)

// Describe the base BSDF as text: "Base BSDF []" terminated by a newline.
std::string BSDF::toString() const
{
    // A constant description needs no stream; the newline replaces the former std::endl.
    return "Base BSDF []\n";
}

// Ratio |(wi.z * <wo_world, n_geo>) / (wo.z * <wi_world, n_geo>)|, where wi/wo are
// local-frame directions (its.wi and wo) and n_geo is the geometric normal.
// The z components are presumably cosines against the shading normal of the local
// frame — confirm against Intersection::toWorld.
// No guard is applied: a zero wo.z() or zero geometric cosine of wi yields inf/NaN.
Float BSDF::correction(const Intersection &its, const Vector &wo) const
{
    // Geometric-normal cosines require world-space directions.
    Vector inWorld = its.toWorld(its.wi);
    Vector outWorld = its.toWorld(wo);
    Float cosInGeo = inWorld.dot(its.geoFrame.n);
    Float cosOutGeo = outWorld.dot(its.geoFrame.n);

    Float numerator = its.wi.z() * cosOutGeo;
    Float denominator = wo.z() * cosInGeo;
    return std::abs(numerator / denominator);
}

// PSDR_IMPLEMENT_BSDF(Types...) defines the bodies of BSDF::isTransmissive(),
// BSDF::isTwosided() and BSDF::isNull() for the given list of concrete BSDF types.
// Each body expands PSDR_MAP_STMT_F(<query>(), Types...). That macro is defined
// elsewhere; presumably it tests the runtime type of `this` against each listed type
// and returns the matching derived-class query result — NOTE(review): confirm.
// If control gets past that statement, assert(false) fires in debug builds. In NDEBUG
// builds the function instead returns true.
// (No comments are placed inside the macro: a `//` comment on a line ending in `\`
// would swallow the next spliced line.)
#define PSDR_IMPLEMENT_BSDF(...)                        \
                                                        \
    bool BSDF::isTransmissive() const                   \
    {                                                   \
        PSDR_MAP_STMT_F(isTransmissive(), __VA_ARGS__); \
        assert(false);                                  \
        return true;                                    \
    }                                                   \
                                                        \
    bool BSDF::isTwosided() const                       \
    {                                                   \
        PSDR_MAP_STMT_F(isTwosided(), __VA_ARGS__);     \
        assert(false);                                  \
        return true;                                    \
    }                                                   \
                                                        \
    bool BSDF::isNull() const                           \
    {                                                   \
        PSDR_MAP_STMT_F(isNull(), __VA_ARGS__);         \
        assert(false);                                  \
        return true;                                    \
    }

//* Every concrete BSDF type must be listed here: insert an entry when a new BSDF
//* is added, otherwise the type queries above presumably fall through to assert(false).
PSDR_IMPLEMENT_BSDF(DiffuseBSDF, RoughConductorBSDF, RoughDielectricBSDF, NullBSDF)

// =============================================================================
// C-style wrapper functions to which the custom augmented-forward and gradient
// functions are bound. Each wrapper calls the corresponding virtual BSDF::__xxx() member.
namespace bsdf
{
    // Free-function wrappers around the BSDF member implementations
    // (BSDF::__eval / __sample / __pdf / __augment_eval / __gradient_eval).
    // Each wrapper writes its result through a trailing reference parameter rather
    // than returning it (except __augment_eval, which returns a void* tape).
    // NOTE(review): identifiers starting with "__" are reserved for the implementation
    // in C++. Renaming them would also require updating the registration table and the
    // INACTIVE_FN entries below, which refer to them by name.

    // Forwards to bsdf->__eval(its, wo, mode, ret).
    // noinline presumably keeps this a distinct call that the custom gradient registered
    // below can be attached to — confirm.
    __attribute__((noinline)) void __eval(const BSDF *bsdf, const Intersection &its, const Vector &wo, EBSDFMode mode,
                                          Spectrum &ret)
    {
        bsdf->__eval(its, wo, mode, ret);
    }

    // Forwards to bsdf->__sample(...); outputs go to wo, pdf, eta and ret.
    // optnone (Clang) disables optimization of this wrapper, presumably so the call
    // survives for the INACTIVE_FN marking below — confirm.
    __attribute__((optnone)) void __sample(const BSDF *bsdf, const Intersection &its, const Array3 &sample, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode,
                  Spectrum &ret)
    {
        bsdf->__sample(its, sample, wo, pdf, eta, mode, ret);
    }

    // Forwards to bsdf->__pdf(its, wo, ret); same optnone rationale as __sample.
    __attribute__((optnone)) void __pdf(const BSDF *bsdf, const Intersection &its, const Vector3 &wo,
               Float &ret)
    {
        bsdf->__pdf(its, wo, ret);
    }

    // Custom augmented-forward pass for __eval: forwards to bsdf->__augment_eval and
    // returns its void* result, presumably the tape later handed to __gradient_eval.
    void *__augment_eval(const BSDF *bsdf, BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, EBSDFMode mode, EBSDFMode d_mode, Spectrum &ret, Spectrum &d_ret)
    {
        return bsdf->__augment_eval(d_bsdf, its, d_its, wo, d_wo, mode, d_mode, ret, d_ret);
    }
    // Custom reverse (gradient) pass for __eval: forwards to bsdf->__gradient_eval,
    // passing along the tape produced by the augmented pass.
    void __gradient_eval(const BSDF *bsdf, BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, EBSDFMode mode, EBSDFMode d_mode, Spectrum &ret, Spectrum &d_ret, void *tape)
    {
        bsdf->__gradient_eval(d_bsdf, its, d_its, wo, d_wo, mode, d_mode, ret, d_ret, tape);
    }
    // Table {original, augmented forward, gradient} binding the custom derivative to
    // __eval. The name follows Enzyme's __enzyme_register_gradient_<name> convention —
    // NOTE(review): confirm the Enzyme version in use expects this layout.
    void *__enzyme_register_gradient_eval[] = {
        (void *)__eval,
        (void *)__augment_eval,
        (void *)__gradient_eval,
    };
    // INACTIVE_FN is defined elsewhere; presumably it marks __pdf and __sample as
    // not differentiated — confirm.
    INACTIVE_FN(__pdf, __pdf);
    INACTIVE_FN(__sample, __sample);
}
// =============================================================================
// Public member functions called by clients; each routes through the bsdf:: wrappers above.
// Evaluate the BSDF for the outgoing direction wo. The call goes through the free
// function bsdf::__eval, which is the symbol the custom gradient is registered on.
Spectrum BSDF::eval(const Intersection &its, const Vector &wo, EBSDFMode mode) const
{
    Spectrum value;
    bsdf::__eval(this, its, wo, mode, value);
    return value;
}
// Sample an outgoing direction via bsdf::__sample. Fills wo, pdf and eta through the
// reference parameters and returns the Spectrum the wrapper writes.
// Kept optnone, matching the wrapper it calls.
__attribute__((optnone)) Spectrum BSDF::sample(const Intersection &its, const Array3 &sample, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode) const
{
    Spectrum out;
    bsdf::__sample(this, its, sample, wo, pdf, eta, mode, out);
    return out;
}
// Return the sampling density for wo, as computed by bsdf::__pdf.
// Kept optnone, matching the wrapper it calls.
__attribute__((optnone)) Float BSDF::pdf(const Intersection &its, const Vector3 &wo) const
{
    Float density;
    bsdf::__pdf(this, its, wo, density);
    return density;
}

// Apply INACTIVE_FN to the public BSDF::pdf and BSDF::sample members, mirroring the
// bsdf::__pdf / bsdf::__sample entries in namespace bsdf. The macro is defined
// elsewhere; presumably it excludes these functions from differentiation — confirm.
INACTIVE_FN(BSDF_pdf, &BSDF::pdf);
INACTIVE_FN(BSDF_sample, &BSDF::sample);

/*
// * For reference only (commented out): direct implementations that statically cast every call to DiffuseBSDF.
__attribute__((optnone)) Float BSDF::pdf(const Intersection &its, const Vector3 &wo) const
{
    return static_cast<const DiffuseBSDF *>(this)->pdf(its, wo);
}
// Evaluate the cosine weighted BSDF value
Spectrum BSDF::eval(const Intersection &its, const Vector &wo,
                    EBSDFMode mode) const
{
    return static_cast<const DiffuseBSDF *>(this)->eval(its, wo, mode);
}

Spectrum BSDF::sample(const Intersection &its, const Array3 &sample,
                      Vector &wo, Float &pdf, Float &eta,
                      EBSDFMode mode) const
{
    return static_cast<const DiffuseBSDF *>(this)->sample(its, sample, wo, pdf, eta, mode);
}

// Check if the BSDF is transmissive
bool BSDF::isTransmissive() const
{
    return static_cast<const DiffuseBSDF *>(this)->isTransmissive();
}

// Check if the BSDF is two-sided
bool BSDF::isTwosided() const
{
    return static_cast<const DiffuseBSDF *>(this)->isTwosided();
}

// Check if the BSDF is null
bool BSDF::isNull() const
{
    return static_cast<const DiffuseBSDF *>(this)->isNull();
}
*/