#pragma once
#include <core/fwd.h>
#include <core/object.h>
#include <sstream>
#include <string>
#include <unordered_map>
struct Intersection;
struct SpatialVertex;
/// Transport mode passed to BSDF evaluation/sampling; BSDF::eval and
/// BSDF::sample default to ERadiance.
/// Enumerator values are 0, 1, 2 in declaration order — keep the order stable.
/// NOTE(review): EImportance presumably selects adjoint (importance) transport and
/// EImportanceWithCorrection additionally applies BSDF::correction — confirm.
enum EBSDFMode
{
    ERadiance = 0,
    EImportance,
    EImportanceWithCorrection
};

/// Identifiers for concrete BSDF kinds; only the diffuse BSDF is listed so far.
/// NOTE(review): not referenced in this header — presumably a type tag; confirm.
enum EBSDF { EDiffuseBSDF = 0 };

/// Abstract base class for BSDFs.
///
/// The non-virtual entry points (eval, correction, sample, pdf, isTransmissive,
/// isTwosided, isNull) are only declared here; their definitions live in another
/// file. Concrete BSDFs implement the pure-virtual "__" hooks, normally by placing
/// DECLEAR_BSDF_HELPER_FUNCTIONS() in the class body and
/// IMPLEMENT_BSDF_HELPER_FUNCTIONS(CLASSNAME) in a source file.
struct BSDF : Object
{
    virtual ~BSDF() = default;

    /// Return a copy of this BSDF.
    /// NOTE(review): ownership of the returned pointer presumably passes to the caller — confirm.
    virtual BSDF *clone() const = 0;
    /// Merge the state of `bsdf` into this BSDF.
    /// NOTE(review): presumably accumulates parameters/gradients — confirm against implementations.
    virtual void merge(BSDF *bsdf) = 0;
    /// Reset this BSDF's state to zero (presumably used for shadow/derivative BSDFs — confirm).
    virtual void setZero() = 0;

    /// Evaluate the cosine-weighted BSDF value for direction `wo` at `its`.
    Spectrum eval(const Intersection &its, const Vector &wo,
                  EBSDFMode mode = EBSDFMode::ERadiance) const;
    /// NOTE(review): correction factor for `wo` at `its`; presumably the term used by
    /// EImportanceWithCorrection — confirm.
    Float correction(const Intersection &its, const Vector &wo) const;
    /// Sample the BSDF and return the BSDF value divided by pdf.
    /// `wo`, `pdf` and `eta` are non-const reference outputs
    /// (eta presumably the relative index of refraction of the sampled event — confirm).
    Spectrum sample(const Intersection &its, const Array3 &sample,
                    Vector &wo, Float &pdf, Float &eta,
                    EBSDFMode mode = EBSDFMode::ERadiance) const;

    /// Compute the probability density of sampling `wo` (given wi, presumably held by `its`).
    //! This function should be marked as inactive
    Float pdf(const Intersection &its, const Vector3 &wo) const;

    /// Check if the BSDF is transmissive
    bool isTransmissive() const;

    /// Check if the BSDF is two-sided
    bool isTwosided() const;

    /// Check if the BSDF is null
    bool isNull() const;

    /// Return a readable string representation of this BSDF
    virtual std::string toString() const;

    /// Return this BSDF's named parameters; the base implementation returns an empty map.
    virtual std::unordered_map<std::string, ParamType> toMap() const
    {
        return {};
    }

    // ======================================================================
    // Virtual hooks implemented by each concrete BSDF (usually generated by
    // DECLEAR_BSDF_HELPER_FUNCTIONS). Results are written to `ret`.
    virtual void __eval(const Intersection &its, const Vector &wo, EBSDFMode mode, Spectrum &ret) const = 0;
    // Sample the BSDF and write the BSDF value divided by pdf to `ret`.
    virtual void __sample(const Intersection &its, const Array3 &rnd, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode, Spectrum &ret) const = 0;
    virtual void __pdf(const Intersection &its, const Vector3 &wo, Float &ret) const = 0;
    // ======================================================================

    // ======================================================================
    // Enzyme reverse-mode hooks for eval (usually generated by
    // IMPLEMENT_BSDF_HELPER_FUNCTIONS). Each primal argument is paired with a
    // shadow (d_bsdf, d_its, d_wo, d_ret). __augment_eval runs the augmented
    // forward pass and returns an opaque tape; __gradient_eval consumes that tape
    // to propagate d_ret back into the other shadows.
    virtual void *__augment_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret) const = 0;
    virtual void __gradient_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret, void *tape) const = 0;

    // ======================================================================

    PSDR_DECLARE_VIRTUAL_CLASS()
};

/*
 * Place inside a concrete BSDF class body to declare its overrides of BSDF's
 * pure-virtual hooks:
 *   - __eval / __sample / __pdf are defined inline and forward to the unqualified
 *     eval / sample / pdf, writing the result to `ret`;
 *   - __augment_eval / __gradient_eval are only declared here and must be defined
 *     by IMPLEMENT_BSDF_HELPER_FUNCTIONS(CLASSNAME) in a source file.
 * NOTE(review): the class should declare its own eval/sample/pdf; otherwise the
 * unqualified calls resolve to BSDF's non-virtual versions, whose definitions are
 * not visible here (if those dispatch through the "__" hooks this would recurse).
 *
 * The historical name is misspelled ("DECLEAR"); DECLARE_BSDF_HELPER_FUNCTIONS()
 * is an equivalent, correctly spelled alias. Both names stay valid.
 */
#define DECLEAR_BSDF_HELPER_FUNCTIONS()                                                                                                          \
    void __eval(const Intersection &its, const Vector &wo, EBSDFMode mode, Spectrum &ret) const override                                         \
    {                                                                                                                                            \
        ret = eval(its, wo, mode);                                                                                                               \
    }                                                                                                                                            \
    void __sample(const Intersection &its, const Array3 &rnd, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode, Spectrum &ret) const override  \
    {                                                                                                                                            \
        ret = sample(its, rnd, wo, pdf, eta, mode);                                                                                              \
    }                                                                                                                                            \
    void __pdf(const Intersection &its, const Vector3 &wo, Float &ret) const override                                                            \
    {                                                                                                                                            \
        ret = pdf(its, wo);                                                                                                                      \
    }                                                                                                                                            \
    void *__augment_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo,                             \
                         EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret) const override;                                 \
    void __gradient_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo,                             \
                         EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret, void *tape) const override;

/* Correctly spelled alias of DECLEAR_BSDF_HELPER_FUNCTIONS(). */
#define DECLARE_BSDF_HELPER_FUNCTIONS() DECLEAR_BSDF_HELPER_FUNCTIONS()

namespace bsdf_detail
{
/// Primal function that Enzyme differentiates for a BSDF's eval(); used by
/// IMPLEMENT_BSDF_HELPER_FUNCTIONS. Writes `bsdf->eval(its, wo, mode)` into `ret`.
///
/// This is a template instantiated once per BSDF type, not a function emitted by the
/// macro, so the macro can be expanded for several BSDF classes in one translation
/// unit without redefinition/ambiguity errors. Also, the BSDF pointer handed to
/// Enzyme already has the derived type whose eval() is called (no implicit
/// base-to-derived pointer reinterpretation).
template <typename Derived, typename Its, typename Dir, typename Mode, typename Ret>
void eval_primal(const Derived *bsdf, const Its &its, const Dir &wo, Mode mode, Ret &ret)
{
    ret = bsdf->eval(its, wo, mode);
}
} // namespace bsdf_detail

/*
 * Defines CLASSNAME::__augment_eval and CLASSNAME::__gradient_eval (declared by
 * DECLEAR_BSDF_HELPER_FUNCTIONS) using Enzyme split-mode reverse AD of
 * bsdf_detail::eval_primal<CLASSNAME, ...>. Both passes name the same primal
 * instantiation, so the tape from __augment_eval matches the one __gradient_eval
 * consumes.
 *
 * Expand once per class, in a source file where CLASSNAME is complete and the
 * Enzyme entry points (__enzyme_augmentfwd, __enzyme_reverse, enzyme_dup,
 * enzyme_const) are declared.
 *
 * Activity: the BSDF, its, wo and ret are duplicated with their shadows (d_bsdf,
 * d_its, d_wo, d_ret). mode is passed as a constant, so d_mode is unused.
 *
 * NOTE(review): the primal writes through `ret`, which these members receive as
 * `const Spectrum &`; callers presumably pass a non-const Spectrum — confirm.
 */
#define IMPLEMENT_BSDF_HELPER_FUNCTIONS(CLASSNAME)                                                                             \
    void *CLASSNAME::__augment_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, \
                                    EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret) const              \
    {                                                                                                                           \
        (void)d_mode;                                                                                                           \
        return __enzyme_augmentfwd(                                                                                             \
            reinterpret_cast<void *>(&::bsdf_detail::eval_primal<CLASSNAME, Intersection, Vector, EBSDFMode, Spectrum>),        \
            enzyme_dup, this, static_cast<CLASSNAME *>(d_bsdf),                                                                 \
            enzyme_dup, &its, &d_its,                                                                                           \
            enzyme_dup, &wo, &d_wo,                                                                                             \
            enzyme_const, mode,                                                                                                 \
            enzyme_dup, &ret, &d_ret);                                                                                          \
    }                                                                                                                           \
    void CLASSNAME::__gradient_eval(BSDF *d_bsdf, const Intersection &its, Intersection &d_its, const Vector &wo, Vector &d_wo, \
                                    EBSDFMode mode, EBSDFMode d_mode, const Spectrum &ret, Spectrum &d_ret, void *tape) const  \
    {                                                                                                                           \
        (void)d_mode;                                                                                                           \
        __enzyme_reverse(                                                                                                       \
            reinterpret_cast<void *>(&::bsdf_detail::eval_primal<CLASSNAME, Intersection, Vector, EBSDFMode, Spectrum>),        \
            enzyme_dup, this, static_cast<CLASSNAME *>(d_bsdf),                                                                 \
            enzyme_dup, &its, &d_its,                                                                                           \
            enzyme_dup, &wo, &d_wo,                                                                                             \
            enzyme_const, mode,                                                                                                 \
            enzyme_dup, &ret, &d_ret, tape);                                                                                    \
    }