#include <core/sampler.h>
#include <medium/heterogeneous.h>

// Builds the medium from scene-description properties:
//   "scale"   - Float multiplier applied to the density grid,
//   "density" - grid of (unscaled) extinction values,
//   "albedo"  - grid of albedo values.
// The delta-tracking majorant is derived afterwards by configure().
Heterogeneous::Heterogeneous(const Properties &props) : Medium(props) {
    m_type = TYPE_ID;

    m_scale  = props.get<Float>("scale");
    m_sigmaT = VolumeGrid(props.get<Properties>("density"));
    m_albedo = VolumeGrid(props.get<Properties>("albedo"));

    configure(); // needs both m_sigmaT and m_scale
}

// Builds the medium directly from an extinction grid, an albedo grid, a
// density multiplier and a phase-function id (forwarded to Medium).
Heterogeneous::Heterogeneous(const VolumeGrid &sigmaT,
                             const VolumeGrid &albedo,
                             Float             scale,
                             int               phase_id)
    : Medium(phase_id) {
    m_type = TYPE_ID;

    m_sigmaT = sigmaT;
    m_albedo = albedo;
    m_scale  = scale;

    configure(); // cache the majorant for the grid/scale just stored
}

// Copy constructor. Copies the two grids, the scale and the phase id, and
// recomputes the majorant. Only phase_id is forwarded to the Medium base,
// exactly as the grid-based constructor does, so delegate to it.
Heterogeneous::Heterogeneous(const Heterogeneous &other)
    : Heterogeneous(other.m_sigmaT, other.m_albedo, other.m_scale,
                    other.phase_id) {}

// Caches the delta-tracking majorant (largest scaled extinction in the grid)
// and its reciprocal, used as the free-flight rate by the tracking loops.
// These cached values are not refreshed automatically if m_sigmaT or m_scale
// are modified later.
void Heterogeneous::configure() {
    m_max_density     = m_scale * m_sigmaT.max();
    m_inv_max_density = 1.0 / m_max_density;
}

// Polymorphic copy via the copy constructor. Returns a raw heap pointer;
// presumably the caller takes ownership (not visible here - confirm).
Medium *Heterogeneous::clone() const {
    Heterogeneous *copy = new Heterogeneous(*this);
    return copy;
}

// Accumulates another Heterogeneous medium into this one. Both grids are
// combined via VolumeGrid::merge, and the scales are added.
//
// `other` must be a non-null Heterogeneous. dynamic_cast returns nullptr
// otherwise, and dereferencing that would be undefined behavior. Such a call
// trips the assert in debug builds and is ignored in release builds.
void Heterogeneous::merge(Medium *other) {
    Heterogeneous *m = dynamic_cast<Heterogeneous *>(other);
    assert(m != nullptr);
    if (m == nullptr)
        return;
    m_sigmaT.merge(m->m_sigmaT);
    m_albedo.merge(m->m_albedo);
    m_scale += m->m_scale;
}

// Zeroes every parameter of the medium: both grids (via VolumeGrid::setZero)
// and the density scale. The cached majorant is left untouched.
void Heterogeneous::setZero() {
    m_sigmaT.setZero();
    m_albedo.setZero();
    m_scale = 0.0;
}

// Distance sampling driven by an explicit 2D random sample. Not implemented
// for this medium. With asserts enabled the call aborts. Under NDEBUG it
// returns false without writing p_scatter or throughput.
bool Heterogeneous::sampleDistance(const Ray &ray, const Float &tmax,
                                   const Array2 &rnd2, RndSampler *sampler,
                                   Vector   &p_scatter,
                                   Spectrum &throughput) const {
    (void)ray;
    (void)tmax;
    (void)rnd2;
    (void)sampler;
    (void)p_scatter;
    (void)throughput;
    assert(false);
    return false;
}
// Stratified Monte Carlo estimate I of \int_{tmin}^{tmax} sigma_t(ray(s)) ds,
// using 20 jittered sample positions. The positions are detached, so only the
// density lookups carry derivatives.
// Returns -I + detach(I) + 1, whose value is 1 and whose gradient is
// -\int \dot\sigma_t(x) dx.
inline Float Heterogeneous::intSigmaT(const Ray &ray, const Float &tmin,
                                      const Float &tmax, RndSampler *sampler) const {
    constexpr int kStrata = 20; // number of line-integral samples (hyperparameter)
    const Float   length  = tmax - tmin;

    Float integral = 0.;
    for (int k = 0; k != kStrata; ++k) {
        // One jittered sample inside stratum k (assuming next1D() lies in [0, 1)).
        Float u = static_cast<Float>(k + sampler->next1D()) / kStrata;
        u       = detach(u);
        integral += m_scale * m_sigmaT.lookupFloat(ray(tmin + length * u));
    }
    integral *= length / static_cast<Float>(kStrata);
    return -integral + detach(integral) + 1; // value: 1
}

// Estimates the transmittance along `ray` between tmin and tmax, restricted
// to the grid's bounding box (the medium is treated as empty outside it).
// The value is the delta-tracking estimate averaged over nSamples walks. The
// gradient is attached through intSigmaT, i.e. d/dθ = -T * \int \dot\sigma_t.
Float Heterogeneous::evalTransmittance(const Ray &ray, const Float &tmin,
                                       const Float &tmax,
                                       RndSampler  *sampler) const {
    Float mint, maxt;
    if (!m_sigmaT.getAABB().rayIntersect(ray, mint, maxt))
        return 1.;
    mint = std::max(tmin, mint);
    maxt = std::min(tmax, maxt);
    // The unbounded ray may hit the box while the query segment [tmin, tmax]
    // lies entirely before or after it. The clipped interval is then empty or
    // reversed. Nothing attenuates the segment, so return 1 directly rather
    // than passing intSigmaT a negative-length interval (which would also flip
    // the sign of the gradient estimate).
    if (!(mint < maxt))
        return 1.;

    // Delta tracking: each walk contributes 1 if it escapes past maxt and 0 if
    // it stops at a real collision.
    const int nSamples = 20;
    Float     result   = 0;
    for (int i = 0; i < nSamples; i++) {
        Float t = mint;
        while (true) {
            t -= std::log(sampler->next1D()) * m_inv_max_density;
            if (t >= maxt) {
                result += 1;
                break;
            }
            Vector p       = ray(t);
            Float  density = m_sigmaT.lookupFloat(p) * m_scale;
            if (density * m_inv_max_density > sampler->next1D())
                break;
        }
    }
    Float ret        = result / nSamples;
    Float int_sigmaT = intSigmaT(ray, mint, maxt, sampler); // value 1, carries gradient
    return int_sigmaT * detach(ret);
}

// Returns a quantity whose value is 1 and whose gradient is
// -\int \dot\sigma_t(x) dx over [tmin, tmax] (see intSigmaT).
// The interval is clipped to the grid's bounding box, matching
// evalTransmittance, which treats the medium as empty outside it. This keeps
// all stratified samples inside the volume and avoids an infinite-length
// integral when tmax is unbounded. Returns 1 (zero gradient) when the segment
// does not overlap the box.
Float Heterogeneous::evalTransmittanceRatio(
    const Ray &ray, const Float &tmin,
    const Float &tmax, RndSampler *sampler) const {
    Float mint, maxt;
    if (!m_sigmaT.getAABB().rayIntersect(ray, mint, maxt))
        return 1.;
    mint = std::max(tmin, mint);
    maxt = std::min(tmax, maxt);
    if (!(mint < maxt))
        return 1.;
    return intSigmaT(ray, mint, maxt, sampler);
}

bool Heterogeneous::sampleDistance(const Ray &ray, const Float &tmax,
                                   RndSampler           *sampler,
                                   MediumSamplingRecord &mRec) const {
    // delta tracking
    // the following information is invalid
    mRec.pdfFailure    = 1.0;
    mRec.pdfSuccess    = 1.0;
    mRec.transmittance = 1.;
    Float mint, maxt;
    if (!m_sigmaT.getAABB().rayIntersect(ray, mint, maxt)) {
        return false;
    }
    mint          = std::max(0., mint);
    maxt          = std::min(tmax, maxt);
    bool  success = false;
    Float t = mint, density_at_t = 0;
    for (int depth = 0;; depth++) {
        if (depth > 10000) {
            PSDR_WARN("Heterogeneous::sampleDistance: depth > 10000");
            break;
        }
        // sample a medium interaction
        t -= std::log(sampler->next1D()) * m_inv_max_density;
        // if we are outside the volume, return false
        if (t >= maxt) {
            mRec.t = maxt;
            mRec.p = ray(mRec.t);
            break;
        }
        Vector p     = ray(t);
        density_at_t = m_sigmaT.lookupFloat(p) * m_scale;
        // decide if this is real interaction or null interaction
        if (density_at_t * m_inv_max_density > sampler->next1D()) {
            mRec.t          = t;
            mRec.p          = p;
            Spectrum albedo = m_albedo.lookupSpectrum(p);
            mRec.sigmaS     = albedo * density_at_t;
            mRec.sigmaT     = density_at_t;
            if (density_at_t != 0.)
                mRec.transmittance = 1. / density_at_t;
            else
                mRec.transmittance = 0.;
            if (!std::isfinite(mRec.transmittance)) {
                // WARN("Transmittance is not finite");
                mRec.transmittance = 0.;
            }
            mRec.wi     = -ray.dir;
            mRec.medium = this;
            success     = true;
            break;
        }
    }
    return success;
}

// Scattering coefficient at x: albedo(x) * density(x) * m_scale.
Spectrum Heterogeneous::sigS(const Vector &x) const {
    Spectrum a       = m_albedo.lookupSpectrum(x);
    Float    density = m_sigmaT.lookupFloat(x);
    return a * density * m_scale;
}

// Instantiates the project's shared Medium boilerplate for Heterogeneous.
// The macro is defined elsewhere; see its definition for the exact members it
// generates.
PSDR_IMPL_MEDIUM_HELPER_FUNCTIONS(Heterogeneous)