#pragma once

#include <core/fwd.h>
#include <core/logger.h>
#include <core/object.h>
#include <core/properties.h>

#include "tetra.hpp"

// Forward declarations: below, these types are only named through pointers or
// references, so their full definitions are not needed in this header.
struct Medium;
struct RndSampler;
struct Ray;

// Upper bound on null interactions. NOTE(review): presumably caps the number
// of fictitious (null-collision) steps taken along one ray — confirm at use sites.
// Namespace-scope constexpr already implies internal linkage (same as `static`).
constexpr int max_null_interactions{100};

/// Record filled by the Medium::sampleDistance overload that takes a
/// MediumSamplingRecord.
///
/// The scalar members and `medium` have default member initializers. A record
/// that a sampler only partially fills therefore holds zeros / nullptr rather
/// than indeterminate values; reading an indeterminate value is undefined
/// behavior. `p`, `wi` and `sigmaS` keep whatever their types' default
/// constructors produce.
///
/// NOTE: with default member initializers this struct is still an aggregate
/// from C++14 on, so `MediumSamplingRecord{...}` brace-initialization keeps
/// working there.
struct MediumSamplingRecord {
    Float         t{};              // presumably distance along the ray to the sampled point
    Vector        p;                // presumably the sampled position
    Vector        wi;               // presumably the incident direction at `p`
    Float         pdfSuccess{};     // pdf of sampling the medium interaction
    Float         pdfFailure{};     // pdf of not sampling any medium interactions
    Spectrum      sigmaS;           // presumably scattering coefficient at `p`
    Float         sigmaT{};         // presumably extinction coefficient at `p`
    const Medium *medium = nullptr; // presumably the medium that produced this record
    Float         transmittance{};
};

/// Abstract base class for participating media.
///
/// Stores:
///  - the index of the phase function used inside the medium (`phase_id`);
///  - a tetrahedral mesh (`m_tetmesh`);
///  - a volume-to-world transform, split into a rotation
///    (`m_volumeToWorld_R`) and a translation (`m_volumeToWorld_T`).
/// Concrete media implement the pure-virtual interface below.
/// `evalTransmittance`, `evalTransmittanceRatio`, `sigS` and `sigT` are
/// declared here and defined out of line.
///
/// NOTE(review): neither constructor initializes `m_volumeToWorld_R` or
/// `m_volumeToWorld_T`. If those types do not zero themselves on default
/// construction, they hold indeterminate values until set elsewhere — confirm.
struct Medium : Object {
    /// Constructs a medium that uses the phase function with index `phase_id`.
    /// `explicit`: an int should never silently convert into a Medium.
    explicit Medium(int phase_id) : phase_id(phase_id) {}

    /// Constructs a medium from scene-description properties, reading the
    /// integer property "phase_id".
    explicit Medium(const Properties &props)
        : phase_id(props.get<int>("phase_id")) {}

    virtual ~Medium() = default;

    /// Returns a copy of this medium. NOTE(review): ownership of the returned
    /// raw pointer presumably passes to the caller — confirm.
    virtual Medium *clone() const = 0;

    /// Merges `other` into this medium; exact semantics are defined by each
    /// subclass. (Parameter name only; overriders are unaffected.)
    virtual void merge(Medium *other) = 0;

    /// Resets this medium's state; exact semantics are defined by subclasses.
    virtual void setZero() = 0;

    /// Replaces the tetrahedral mesh. `vertices` and `ids` must have the same
    /// length; this is only checked by `assert` (i.e. not in NDEBUG builds).
    void setTetmesh(const std::vector<Vector>              &vertices,
                    const std::vector<Vector3i>            &indices,
                    const std::vector<std::pair<int, int>> &ids) {
        assert(vertices.size() == ids.size());
        m_tetmesh = TetrahedronMesh(vertices, indices, ids);
    }

    /// Returns tetrahedron `i` of the mesh (presumably its four vertex indices).
    Vector4i getTet(int i) const { return m_tetmesh.getTet(i); }

    /// Returns vertex `i` of the mesh; the lookup is delegated to the mesh
    /// together with `scene`.
    Vector getVertex(const Scene &scene, int i) const {
        return m_tetmesh.getVertex(scene, i);
    }

    /// Samples a distance along `ray` up to `tmax` using `rnd2` / `sampler`,
    /// writing the scatter position to `p_scatter` and updating `throughput`.
    /// NOTE(review): the return value presumably reports whether an interaction
    /// was sampled before `tmax` — confirm in the implementations.
    virtual bool sampleDistance(const Ray &ray, const Float &tmax,
                                const Array2 &rnd2, RndSampler *sampler,
                                Vector   &p_scatter,
                                Spectrum &throughput) const = 0;

    /// Same as above, but reports the result through `mRec`.
    virtual bool sampleDistance(const Ray &ray, const Float &tmax,
                                RndSampler           *sampler,
                                MediumSamplingRecord &mRec) const = 0;

    /// Transmittance along `ray` between `tmin` and `tmax` (defined out of line).
    Float evalTransmittance(const Ray &ray, const Float &tmin,
                            const Float &tmax, RndSampler *sampler) const;

    /// Ratio-tracking variant of the transmittance estimate (defined out of
    /// line). NOTE(review): estimator type inferred from the name — confirm.
    Float evalTransmittanceRatio(const Ray &ray, const Float &tmin,
                                 const Float &tmax, RndSampler *sampler) const;

    /// Whether this medium is homogeneous.
    virtual bool isHomogeneous() const = 0;

    /// Presumably the scattering (sigS) and extinction (sigT) coefficients at
    /// `x`; both are defined out of line.
    Spectrum sigS(const Vector &x) const;
    Float    sigT(const Vector &x) const;

    /// Index of the phase function used inside this medium.
    int phase_id;

    // rotation and translation.
    Matrix3x3       m_volumeToWorld_R;
    Vector          m_volumeToWorld_T;
    TetrahedronMesh m_tetmesh;

    // Keep in place: the expansion of this macro is defined elsewhere, so
    // moving members across it is not known to be safe.
    PSDR_DECLARE_VIRTUAL_CLASS()

    // --------------------------------------------------
    // Type tag, -1 until set. NOTE(review): presumably used by the out-of-line
    // code to dispatch on the concrete medium type — confirm.
    int m_type = -1;
};

// Helper-function declarations.
//
// For a concrete medium class CLS, these macros declare free functions named
// __<CLS>_sigS, __<CLS>_evalTransmittance and __<CLS>_evalTransmittanceRatio.
// Each helper takes a `const CLS *` first and hands its result back through
// the trailing `ret` out-parameter. The matching definitions come from the
// PSDR_IMPL_MEDIUM_HELPER_* macros.
// NOTE(review): the free-function / out-parameter shape presumably exists so
// the helpers can be passed around as plain function pointers — confirm.

#define PSDR_DECLARE_MEDIUM_HELPER_SIGS(CLS) \
    void __##CLS##_sigS(const CLS *medium, const Vector &x, Spectrum &ret)

#define PSDR_DECLARE_MEDIUM_HELPER_EVAL_TRANSMITTANCE(CLS)                 \
    void __##CLS##_evalTransmittance(const CLS *medium, const Ray &ray,    \
                                     const Float &tmin, const Float &tmax, \
                                     RndSampler *sampler, Float &ret)

#define PSDR_DECLARE_MEDIUM_HELPER_EVAL_TRANSMITTANCE_RATIO(CLS)                \
    void __##CLS##_evalTransmittanceRatio(const CLS *medium, const Ray &ray,    \
                                          const Float &tmin, const Float &tmax, \
                                          RndSampler *sampler, Float &ret)

// Declares all three helpers for CLS; the caller writes the final ';'.
#define PSDR_DECLARE_MEDIUM_HELPER_FUNCTIONS(CLS)       \
    PSDR_DECLARE_MEDIUM_HELPER_SIGS(CLS);               \
    PSDR_DECLARE_MEDIUM_HELPER_EVAL_TRANSMITTANCE(CLS); \
    PSDR_DECLARE_MEDIUM_HELPER_EVAL_TRANSMITTANCE_RATIO(CLS)

// Invocation helpers: downcast `med` with dynamic_cast<const CLS *> and call
// the matching __<CLS>_* helper with the remaining arguments. If `med` does
// not actually point to a CLS, the helper receives nullptr.

#define PSDR_INVOKE_MEDIUM_HELPER_SIGS(CLS, med, pos, out) \
    __##CLS##_sigS(dynamic_cast<const CLS *>(med), pos, out)

#define PSDR_INVOKE_MEDIUM_HELPER_EVAL_TRANSMITTANCE(CLS, med, r, t0, t1, smp, out) \
    __##CLS##_evalTransmittance(dynamic_cast<const CLS *>(med), r, t0, t1, smp, out)

#define PSDR_INVOKE_MEDIUM_HELPER_EVAL_TRANSMITTANCE_RATIO(CLS, med, r, t0, t1, smp, out) \
    __##CLS##_evalTransmittanceRatio(dynamic_cast<const CLS *>(med), r, t0, t1, smp, out)

// Helper-function definitions.
//
// Each macro defines __<CLS>_<method>. The generated function calls the
// same-named member function through a `const CLS *` (so CLS's own overload is
// the one that is chosen) and stores the result in the `ret` out-parameter.
// NOTE(review): `medium` is dereferenced unconditionally, so callers must not
// pass nullptr (for example, from a failed dynamic_cast).

#define PSDR_IMPL_MEDIUM_HELPER_SIGS(CLS)                                      \
    void __##CLS##_sigS(const CLS *medium, const Vector &x, Spectrum &ret) { \
        ret = medium->sigS(x);                                               \
    }

#define PSDR_IMPL_MEDIUM_HELPER_EVAL_TRANSMITTANCE(CLS)                      \
    void __##CLS##_evalTransmittance(const CLS *medium, const Ray &ray,    \
                                     const Float &tmin, const Float &tmax, \
                                     RndSampler *sampler, Float &ret) {    \
        ret = medium->evalTransmittance(ray, tmin, tmax, sampler);         \
    }

#define PSDR_IMPL_MEDIUM_HELPER_EVAL_TRANSMITTANCE_RATIO(CLS)                     \
    void __##CLS##_evalTransmittanceRatio(const CLS *medium, const Ray &ray,    \
                                          const Float &tmin, const Float &tmax, \
                                          RndSampler *sampler, Float &ret) {    \
        ret = medium->evalTransmittanceRatio(ray, tmin, tmax, sampler);         \
    }

// Defines all three helpers for CLS.
#define PSDR_IMPL_MEDIUM_HELPER_FUNCTIONS(CLS)       \
    PSDR_IMPL_MEDIUM_HELPER_SIGS(CLS);               \
    PSDR_IMPL_MEDIUM_HELPER_EVAL_TRANSMITTANCE(CLS); \
    PSDR_IMPL_MEDIUM_HELPER_EVAL_TRANSMITTANCE_RATIO(CLS)