#pragma once

#include <core/ptr.h>
#include <render/medium.h>
#include <render/volumegrid.h>

#include <sstream>
#include <string>

/**
 * Medium whose extinction coefficient is looked up per position from a
 * VolumeGrid (m_sigmaT) and multiplied by a global scale factor (m_scale).
 * A second grid (m_albedo) is stored alongside it.
 *
 * Only sigT(), isHomogeneous() and toString() are defined in this header;
 * every other member is defined out of line.
 */
struct Heterogeneous : public Medium {
    // Constructors; definitions are not in this header.
    Heterogeneous(const Properties &props);
    Heterogeneous(const VolumeGrid &sigmaT,
                  const VolumeGrid &albedo,
                  Float             scale,
                  int               phase_id);
    Heterogeneous(const Heterogeneous &other);

    // NOTE(review): presumably derives cached values such as m_max_density /
    // m_inv_max_density from the grids — confirm against the implementation.
    void    configure();
    Medium *clone() const override;
    void    merge(Medium *other) override;
    void    setZero() override;

    // Distance sampling along `ray` up to `tmax`; this overload reports the
    // result through `p_scatter` and `throughput`.
    bool sampleDistance(const Ray &ray, const Float &tmax, const Array2 &rnd2,
                        RndSampler *sampler, Vector &p_scatter,
                        Spectrum &throughput) const override;

    // Distance sampling along `ray` up to `tmax`; this overload reports the
    // result through a MediumSamplingRecord.
    bool sampleDistance(const Ray &ray, const Float &tmax, RndSampler *sampler,
                        MediumSamplingRecord &mRec) const override;

    // NOTE(review): declared `inline` but not defined in this header; an
    // inline function must be defined in every translation unit that uses it.
    // Confirm that it is only called from the file that defines it.
    inline Float intSigmaT(const Ray &ray, const Float &tmin,
                           const Float &tmax, RndSampler *sampler) const;

    Float evalTransmittance(const Ray &ray, const Float &tmin,
                            const Float &tmax, RndSampler *sampler) const;

    Float evalTransmittanceRatio(const Ray &ray, const Float &tmin,
                                 const Float &tmax, RndSampler *sampler) const;

    // The extinction varies with position (see sigT), so this medium must not
    // report itself as homogeneous. The previous `true` appears to have been
    // carried over from the homogeneous medium.
    inline bool isHomogeneous() const override { return false; }

    Spectrum     sigS(const Vector &x) const;

    // Extinction at position x: the m_sigmaT grid lookup times m_scale.
    inline Float sigT(const Vector &x) const {
        return m_sigmaT.lookupFloat(x) * m_scale;
    }

    // Short human-readable tag identifying this medium type.
    inline std::string toString() const {
        return "Heterogeneous Medium...";
    }

    VolumeGrid m_sigmaT, m_albedo; // grids read by sigT() / stored for albedo
    Float      m_scale;            // multiplier applied to m_sigmaT lookups
    // NOTE(review): presumably the maximum of the scaled density and its
    // reciprocal, cached for distance sampling — confirm in configure().
    Float      m_max_density;
    Float      m_inv_max_density;

    static const int TYPE_ID = 2;

    PSDR_DECLARE_CLASS(Heterogeneous)
    PSDR_IMPLEMENT_VIRTUAL_CLASS(Heterogeneous)
    //==========================================================================
};

// Invokes the project's medium helper-function macro for Heterogeneous.
// NOTE(review): the macro is defined elsewhere; presumably it declares the
// free-function helpers associated with this medium type — confirm there.
PSDR_DECLARE_MEDIUM_HELPER_FUNCTIONS(Heterogeneous);
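// NOTE: everything below is disabled (commented-out) code kept for reference:
// a hand-written Enzyme custom-gradient registration for sigS.
// namespace heterogeneous {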
// __attribute__((optnone)) void __sigS(const Heterogeneous *medium,
//                                      const Vector &x, Spectrum &ret) {
//     ret = medium->sigS(x);
// }

// void *sigS_fwd(const Heterogeneous *medium, Heterogeneous *d_medium,
//                const Vector &x, Vector &d_x, //
//                Spectrum &ret, Spectrum &d_ret) {
//     return __enzyme_augmentfwd((void *) __sigS,              //
//                                enzyme_dup, medium, d_medium, //
//                                enzyme_dup, &x, &d_x,         //
//                                enzyme_dup, &ret, &d_ret);
// }

// void sigS_bwd(const Heterogeneous *medium, Heterogeneous *d_medium,
//               const Vector &x, Vector &d_x,   //
//               Spectrum &ret, Spectrum &d_ret, //
//               void *tape) {
//     __enzyme_reverse((void *) __sigS,              //
//                      enzyme_dup, medium, d_medium, //
//                      enzyme_dup, &x, &d_x,         //
//                      enzyme_dup, &ret, &d_ret,     //
//                      tape);
// }

// void *__enzyme_register_gradient_sigS[3] = {
//     (void *) __sigS,
//     (void *) sigS_fwd,
//     (void *) sigS_bwd,
// };

// Spectrum sigS(const Heterogeneous *medium, const Vector &x) {
//     Spectrum ret = Spectrum::Zero();
//     __sigS(medium, x, ret);
//     return ret;
// }
// } // namespace heterogeneous