#include "volpathBase.h"

#include <algorithm1.h>
#include <core/sampler.h>
#include <fmt/color.h>
#include <render/medium.h>
#include <render/scene.h>
namespace volpath_base {
// Standalone volumetric path tracer with optional light-path recording.
//
// Traces `_ray` through the scene, alternating between medium interactions
// (sampled with medium->sampleDistance) and surface interactions. At every
// scattering vertex it performs emitter sampling (NEE) and phase/BSDF
// sampling, combining both strategies with miWeight().
// If `path` is non-null, the camera, every medium vertex, every non-null-BSDF
// surface vertex, and every non-zero emitter connection (NEE and
// phase/BSDF-sampled) are appended to it together with the pdf values
// tracked alongside (lines marked `// NOTE`).
//
//   scene - scene to intersect / sample against
//   _ray  - initial ray; a local copy is advanced bounce by bounce
//   rRec  - provides pixel_idx, med_id (-1 = no medium), sampler, max_bounces
//   path  - optional LightPath recorder; may be nullptr
// Returns the radiance estimate accumulated along the path.
//
// NOTE(review): identifiers beginning with a double underscore are reserved
// for the implementation in C++; consider renaming __Li.
// `max_null_interactions` and `verbose` are defined outside this chunk.
Spectrum __Li(const Scene &scene, const Ray &_ray, const RadianceQueryRecord &rRec, LightPath *path) {
    if (path) {
        path->clear(rRec.pixel_idx);
        path->append(scene.camera); // NOTE
    }

    Spectrum ret        = Spectrum::Zero(),
             throughput = Spectrum::Ones();

    // FIXME: assume the camera is outside any shapes
    int           med_id = rRec.med_id;
    const Medium *medium = (med_id != -1 ? scene.getMedium(med_id) : nullptr);

    Float pdfFailure = 1.; // keep track of the pdf of hitting a surface
    Float pdfSuccess = 1.; // keep track of the pdf of hitting a medium

    // Previous scattering vertex; used in the geometric() factor of recorded
    // vertices once incEmission has become false.
    Vector preX        = scene.camera.cpos;
    // Directly hit emission is only added before the first scattering event;
    // afterwards emitters reached by phase/BSDF sampling are added with an MIS
    // weight at the vertex that sampled them.
    bool   incEmission = true;

    Ray          ray(_ray);
    Intersection its;
    scene.rayIntersect(ray, false, its);

    MediumSamplingRecord mRec;
    RndSampler          *sampler     = rRec.sampler;
    const int            max_bounces = rRec.max_bounces;
    // null_interations counts crossings of null-BSDF surfaces, which do not
    // increase depth; it is capped to avoid endless loops.
    int                  depth = 0, null_interations = 0;
    while (depth <= max_bounces && null_interations < max_null_interactions) {
        // Presumably true when a medium event is sampled before the surface at
        // distance its.t; mRec is also read in the surface branch below.
        bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
        if (inside_med) {
            // sampled a medium interaction

            if (depth >= max_bounces)
                break;

            if (path) {
                if (incEmission)
                    path->append({ mRec, med_id, pdfSuccess * mRec.sigmaT }); // NOTE
                else
                    path->append({ mRec, med_id, pdfSuccess * geometric(preX, mRec.p) * mRec.sigmaT }); // NOTE
            }

            const PhaseFunction *phase = scene.phase_list[medium->phase_id];
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;

            // ====================== emitter sampling =========================
            DirectSamplingRecord dRec(mRec.p);
            Spectrum             value = scene.sampleAttenuatedEmitterDirect(
                            dRec, sampler->next2D(), sampler, mRec.medium);
            if (!value.isZero(Epsilon)) {
                Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                if (phaseVal > Epsilon) {
                    Float phasePdf   = phase->pdf(-ray.dir, dRec.dir);
                    Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                    ret += throughput * value * phaseVal * mis_weight;

                    // NOTE(review): divides by mis_weight; a zero weight would
                    // produce inf here.
                    if (path)
                        path->append_nee({ dRec, dRec.pdf / mis_weight }); // NOTE
                }
            }

            // ====================== phase sampling =============================
            Vector wo;
            Float  phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
            Float  phasePdf = phase->pdf(-ray.dir, wo);
            // Terminate the path when the sampled phase weight is negligible.
            if (phaseVal < Epsilon)
                break;

            throughput *= phaseVal;
            // Whatever the next vertex is (surface or medium), it was reached
            // through this phase-sampled direction.
            pdfFailure = phasePdf;
            pdfSuccess = phasePdf;
            ray        = Ray(mRec.p, wo);

            // Also updates `its`, which the next iteration uses.
            value = scene.rayIntersectAndLookForEmitter(
                ray, false, sampler, mRec.medium, its, dRec);
            if (!value.isZero(Epsilon)) {
                // NOTE(review): the emitter pdf is re-evaluated here via
                // pdfEmitterDirect, whereas the surface branch uses dRec.pdf
                // directly; confirm both agree.
                Float pdf_emitter = scene.pdfEmitterDirect(dRec);
                Float mis_weight  = miWeight(phasePdf, pdf_emitter / dRec.G);
                ret += throughput * value * mis_weight;

                if (path)
                    path->append_bsdf({ dRec, phasePdf * dRec.G / mis_weight }); // NOTE
            }

            // update loop variables
            incEmission = false;
            preX        = mRec.p;
            depth++;
        } else {
            // sampled a surface interaction

            // The surface was reached through a medium: account for
            // transmittance and the pdf of hitting the surface.
            if (medium) {
                pdfFailure *= mRec.pdfFailure;
                // pdfSuccess *= mRec.pdfFailure;
                throughput *= mRec.transmittance / mRec.pdfFailure;
            }
            if (!its.isValid())
                break;

            if (its.isEmitter() && incEmission)
                ret += throughput * its.Le(-ray.dir);

            if (depth >= max_bounces)
                break;

            // ====================== emitter sampling =========================
            DirectSamplingRecord dRec(its);
            // Null-BSDF surfaces get no NEE and are not recorded in `path`.
            if (!its.getBSDF()->isNull()) {
                if (path) {
                    if (incEmission)
                        path->append({ its, pdfFailure }); // NOTE
                    else
                        path->append({ its, pdfFailure * geometric(preX, its.p, its.geoFrame.n) }); // NOTE
                }
                // Stop when the hit belongs to an environment-map emitter.
                if (dynamic_cast<const EnvironmentMap *>(its.ptr_emitter))
                    break;

                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, its, sampler->next2D(), sampler, medium);

                if (!value.isZero(Epsilon)) {
                    Spectrum bsdfVal    = its.evalBSDF(its.toLocal(dRec.dir));
                    Float    bsdfPdf    = its.pdfBSDF(its.toLocal(dRec.dir));
                    Float    mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                    ret += throughput * value * bsdfVal * mis_weight;
                    if (path)
                        path->append_nee({ dRec, dRec.pdf / mis_weight }); // NOTE
                }
            }

            // ====================== BSDF sampling =============================
            Vector   wo;
            Float    bsdfPdf, bsdfEta;
            Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
            if (bsdfWeight.isZero(Epsilon))
                break;

            // sampleBSDF works in the local frame; convert before tracing.
            wo  = its.toWorld(wo);
            ray = Ray(its.p, wo);

            throughput *= bsdfWeight;
            if (its.isMediumTransition()) {
                med_id = its.getTargetMediumId(wo);
                medium = its.getTargetMedium(wo);
            }
            if (its.getBSDF()->isNull()) {
                // Pass straight through: no bounce is counted and incEmission
                // and the tracked pdfs are left unchanged.
                scene.rayIntersect(ray, true, its);
                null_interations++;
            } else {
                pdfFailure     = bsdfPdf;
                pdfSuccess     = bsdfPdf;
                // Also updates `its`, which the next iteration uses.
                Spectrum value = scene.rayIntersectAndLookForEmitter(
                    ray, true, sampler, medium, its, dRec);
                if (!value.isZero(Epsilon)) {
                    Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);

                    ret += throughput * value * mis_weight;
                    if (path)
                        path->append_bsdf({ dRec, bsdfPdf * dRec.G / mis_weight }); // NOTE
                }
                incEmission = false;
                preX        = ray.org;
                depth++;
            }
        }
    }
    if (null_interations == max_null_interactions) {
        if (verbose)
            PSDR_INFO("Max null interactions {} reached. Dead loop?", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
    }
    return ret;
}

// Writes (x - detach(x)) . detach(n) into `res` for the surface point of
// `its`, where x and n come from scene.getPoint(its).
// Assuming detach() returns a value-equal copy that the AD tool treats as a
// constant (TODO confirm), the primal value is zero and only the derivative
// (the normal component of the point's motion) is meaningful.
// This is the function differentiated by d_velocity.
void velocity(const Scene &scene, const Intersection &its, Float &res) {
    auto [x, n, J] = scene.getPoint(its);
    res            = (x - detach(x)).dot(detach(n));
}

// Back-propagates the adjoint `d_u` of velocity()'s output into the scene
// derivative `sceneAD.getDer()` using Enzyme's __enzyme_autodiff; `its` is
// passed as a constant. Without ENZYME defined this function does nothing.
// NOTE(review): `u` is uninitialized; it only serves as the primal output slot
// whose shadow is `d_u`.
void d_velocity(SceneAD &sceneAD, const Intersection &its, Float d_u) {
    auto                  &d_scene = sceneAD.getDer();
    [[maybe_unused]] Float u;
#if defined(ENZYME)
    __enzyme_autodiff((void *) velocity,
                      enzyme_dup, &sceneAD.val, &d_scene,
                      enzyme_const, &its,
                      enzyme_dup, &u, &d_u);
#endif
}
} // namespace volpath_base

// Builds the integrator by forwarding `props` to MISIntegrator; no additional
// initialization happens in this constructor.
VolpathBase::VolpathBase(const Properties &props)
    : MISIntegrator(props) {}

/**
 * Emitter-sampling half of next-event estimation at a medium point.
 *
 * Samples an emitter as seen from `p` (attenuated through `medium`), multiplies
 * by the phase function for the connection direction, and applies an MIS weight
 * against phase-function sampling (neePhase implements the other strategy).
 *
 * @param scene   scene providing emitters and phase functions
 * @param sampler random-number source
 * @param medium  medium containing `p`; dereferenced below, so it must not be null
 * @param p       scattering location
 * @param wi      first argument to the phase function (Li passes -ray.dir via nee())
 * @return MIS-weighted contribution; zero when emitter sampling yields zero, the
 *         phase value is below Epsilon, or neither EMISBalance nor EMISPower is
 *         set in m_sampling_mode
 */
Spectrum VolpathBase::neeEmitter(const Scene &scene, RndSampler *sampler,
                                 const Medium *medium, const Vector &p, const Vector &wi) {
    // assert(medium);
    DirectSamplingRecord dRec(p);
    Spectrum             value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, medium);
    if (value.isZero(Epsilon))
        return Spectrum(0.);
    const PhaseFunction *phase    = scene.getPhase(medium->phase_id);
    Float                phaseVal = phase->eval(wi, dRec.dir);
    if (phaseVal < Epsilon)
        return Spectrum(0.);

    value *= phaseVal;

    // handle mis
    {
        // The emitter pdf must be in the same measure that neePhase uses
        // (pdfEmitterDirect / G). Otherwise the weights of the two strategies
        // do not sum to one and the combined estimator is biased. __Li divides
        // dRec.pdf by dRec.G for this exact sampleAttenuatedEmitterDirect call.
        Float mis_weight = 0.;
        Float pdfA       = dRec.pdf / dRec.G,
              pdfB       = phase->pdf(wi, dRec.dir);
        if (m_sampling_mode & EMISBalance) {
            mis_weight = misWeightBalance(pdfA, pdfB);
        } else if (m_sampling_mode & EMISPower) {
            mis_weight = misWeightPower(pdfA, pdfB);
        }
        value *= mis_weight;
    }

    return value;
}

/**
 * Phase-sampling half of next-event estimation at a medium point.
 *
 * Samples an outgoing direction from the phase function, traces it, and, if an
 * emitter is reached, returns its (attenuated) radiance times the phase
 * sampling weight, MIS-weighted against emitter sampling (neeEmitter).
 *
 * @param scene   scene providing emitters and phase functions
 * @param sampler random-number source
 * @param medium  medium containing `p`; must not be null
 * @param p       scattering location
 * @param wi      first argument to the phase function (Li passes -ray.dir via nee())
 * @return MIS-weighted contribution; zero when the phase weight is below
 *         Epsilon, no emitter is hit, or neither EMISBalance nor EMISPower is
 *         set in m_sampling_mode
 */
Spectrum VolpathBase::neePhase(const Scene &scene, RndSampler *sampler,
                               const Medium *medium, const Vector &p, const Vector &wi) {
    assert(medium);
    Vector               wo;
    const PhaseFunction *phase    = scene.getPhase(medium->phase_id);
    Float                phaseVal = phase->sample(wi, sampler->next2D(), wo); // phase=1
    if (phaseVal < Epsilon)
        return Spectrum(0.);

    Ray                  ray(p, wo);
    Intersection         its;
    DirectSamplingRecord dRec(p);
    Spectrum             value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, medium, its, dRec);
    if (value.isZero(Epsilon))
        return Spectrum(0.);

    Spectrum ret = value * phaseVal;

    // handle mis
    {
        Float mis_weight = 0.;
        Float pdfA       = phase->pdf(wi, wo),
              pdfB       = scene.pdfEmitterDirect(dRec) / dRec.G;
        if (m_sampling_mode & EMISBalance) {
            mis_weight = misWeightBalance(pdfA, pdfB);
        } else if (m_sampling_mode & EMISPower) {
            mis_weight = misWeightPower(pdfA, pdfB);
        }
        // The weight must be applied to the value that is returned. Weighting
        // the unused `value` left this strategy unweighted and double-counted
        // the light that neeEmitter also estimates.
        ret *= mis_weight;
    }

    return ret;
}

/**
 * Full next-event estimate at a medium point: the sum of the MIS-weighted
 * emitter-sampling (neeEmitter) and phase-sampling (neePhase) strategies.
 *
 * Both strategies draw from the same `sampler`. The evaluation order of the
 * operands of `a() + b()` is unspecified in C++, so the two calls are made in
 * separate statements. This keeps the order of random-number consumption fixed
 * (emitter first, then phase) regardless of the compiler.
 *
 * @param medium medium containing `p`; must not be null
 * @param p      scattering location
 * @param wi     first argument to the phase function
 */
Spectrum VolpathBase::nee(const Scene &scene, RndSampler *sampler,
                          const Medium *medium, const Vector &p, const Vector &wi) {
    assert(medium);
    Spectrum ret = neeEmitter(scene, sampler, medium, p, wi);
    ret += neePhase(scene, sampler, medium, p, wi);
    return ret;
}

/**
 * In-scattered radiance estimate at a medium point.
 *
 * Inputs: medium, p, wi (optional), max_bounces.
 *
 * Adds next-event estimation at `p` via handleNee(scene, ...), which also
 * forwards it to the handleNee(Spectrum) hook. It then continues the path along
 * a phase-sampled direction with incEmission = false, presumably because
 * emitters reached this way are already covered by the MIS-weighted
 * phase-sampling strategy inside nee().
 *
 * @param medium      medium containing `p`; dereferenced, must not be null
 * @param p           scattering location
 * @param wi          first argument to the phase function
 * @param max_bounces remaining bounce budget; negative values return zero
 * @return estimated in-scattered radiance
 */
Spectrum VolpathBase::Lins(const Scene &scene, RndSampler *sampler,
                           const Medium *medium, const Vector &p, const Vector &wi, int max_bounces) {
    Spectrum ret(0.);
    if (max_bounces < 0)
        return ret;
    // nee
    ret += handleNee(scene, sampler, medium, p, wi);

    // Sample a direction from the phase function. The returned value is the
    // sampling weight, which Li() and __Li() fold into their throughput. It was
    // previously discarded, which is only correct when the weight is exactly 1.
    const PhaseFunction *phase    = scene.getPhase(medium->phase_id);
    Vector               wo;
    Float                phaseVal = phase->sample(wi, sampler->next2D(), wo);

    // Compute radiance. Li() is still called even for a ~0 weight, so random-
    // number consumption and hook invocations are unchanged.
    ret += Li(scene, sampler, medium, Ray(p, wo), max_bounces - 1, false) * phaseVal;
    return ret;
}

/**
 * Entry point for a primary (camera) ray.
 *
 * Runs the handleSensor hook first unless ESkipSensor is set in
 * m_sampling_mode. Path tracing via Li() is only performed when more than one
 * bounce is allowed; otherwise zero radiance is returned.
 */
Spectrum VolpathBase::Li1(const Scene &scene, RndSampler *sampler,
                          const Medium *medium, const Ray &ray, Array2i pixel_idx,
                          int max_bounces, bool incEmission) {
    if (!(m_sampling_mode & ESkipSensor)) {
        handleSensor(scene, sampler, ray, pixel_idx, medium, max_bounces);
    }

    if (max_bounces > 1) {
        return Li(scene, sampler, medium, ray, max_bounces, incEmission);
    }
    return Spectrum(0.);
}

// Volumetric path tracer used by VolpathBase, instrumented with hooks.
//
// Per event it invokes: isNext(depth) before each iteration, handleNee() with
// the throughput-weighted NEE value, handleMedium() at medium events,
// handleEmission() on emitter hits, handleSurface() at non-null surfaces, and
// report() with the final depth.
//
//   scene       - scene to intersect / sample against
//   sampler     - random-number source
//   medium      - medium containing the ray origin (nullptr for none); replaced
//                 at medium transitions
//   _ray        - starting ray; a local copy is advanced bounce by bounce
//   max_bounces - maximum number of scattering events (null-BSDF crossings do
//                 not count)
//   incEmission - passed to handleEmission for hits before the first
//                 scattering event; set to false afterwards
// Returns the accumulated radiance estimate.
Spectrum VolpathBase::Li(const Scene &scene, RndSampler *sampler,
                         const Medium *medium, const Ray &_ray, int max_bounces, bool incEmission) {
    Spectrum ret        = Spectrum::Zero(),
             throughput = Spectrum::Ones();

    Ray          ray(_ray);
    Intersection its;
    scene.rayIntersect(ray, false, its);

    MediumSamplingRecord mRec;
    int                  depth = 0, null_interations = 0;
    // FIXME
    while (depth <= max_bounces && null_interations < max_null_interactions) {
        
        if (!isNext(depth))
            break;
        
        bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
        if (inside_med) {
            // sampled a medium interaction

            if (depth >= max_bounces)
                break;

            const PhaseFunction *phase = scene.phase_list[medium->phase_id];
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            // nee, emitter sampling + phase sampling
            // NOTE(review): despite its name, `nee_emitter` holds the sum of
            // both NEE strategies returned by nee(). The hook receives the
            // throughput-weighted value here, while handleNee(scene, ...)
            // forwards the unweighted one; confirm which the hook expects.
            Spectrum nee_emitter = nee(scene, sampler, medium, mRec.p, -ray.dir);
            ret += throughput * nee_emitter;
            handleNee(throughput * nee_emitter);
            // ret += throughput * handleNee(scene, sampler, medium, mRec.p, -ray.dir);


            // sample direction
            // Unlike __Li, a negligible phaseVal does not terminate the path;
            // it simply zeroes the throughput.
            Vector wo;
            Float  phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);


            throughput *= phaseVal;

            // do something on the medium event, like branching out for boundary
            handleMedium(scene, sampler, medium, mRec.p, -ray.dir, wo, max_bounces - depth - 1, throughput);
            
            ray = Ray(mRec.p, wo);
            scene.rayIntersect(ray, false, its);

            // update loop variables
            incEmission = false;
            depth++;
        } else {
            // sampled a surface interaction
            // Reached through a medium: account for transmittance over the pdf
            // of not scattering before the surface.
            if (medium) {
                throughput *= mRec.transmittance / mRec.pdfFailure;
            }
            if (!its.isValid())
                break;

            if (its.isEmitter()){
                ret += throughput * handleEmission(its, -ray.dir, incEmission);
                // An environment-map hit ends the path.
                if(dynamic_cast<const EnvironmentMap *>(its.ptr_emitter))
                    break;
            }

            // Null-BSDF surfaces are still passed through at the bounce limit.
            if (!its.getBSDF()->isNull() && depth >= max_bounces)
                break;
            
            // Sum of both NEE contributions at this surface (already multiplied
            // by throughput), reported through handleNee() below.
            Spectrum nee_value(0.);
            // ====================== emitter sampling =========================
            DirectSamplingRecord dRec(its);
            if (!its.getBSDF()->isNull()) {
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, its, sampler->next2D(), sampler, medium);

                if (!value.isZero(Epsilon)) {
                    Spectrum bsdfVal    = its.evalBSDF(its.toLocal(dRec.dir));
                    Float    bsdfPdf    = its.pdfBSDF(its.toLocal(dRec.dir));
                    Float    mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);

                    Spectrum nee_emitter = throughput * value * bsdfVal * mis_weight;
                    nee_value += nee_emitter;
                    ret += nee_emitter;
                }
            }

            // ====================== BSDF sampling =============================
            Vector   wo;
            Float    bsdfPdf, bsdfEta;
            Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
            // NOTE: need to comment out these code, otherwise the handleNee will go wrong
            // if (bsdfWeight.isZero(Epsilon))
            //     break;

            // sampleBSDF works in the local frame; convert before tracing.
            wo  = its.toWorld(wo);
            ray = Ray(its.p, wo);

            throughput *= bsdfWeight;
            if (its.isMediumTransition()) {
                medium = its.getTargetMedium(wo);
            }
            if (its.getBSDF()->isNull()) {
                // Pass straight through without counting a bounce or calling
                // the surface/NEE hooks.
                scene.rayIntersect(ray, true, its);
                null_interations++;
            } else {
                // Also updates `its`, which handleSurface() and the next
                // iteration use.
                Spectrum value = scene.rayIntersectAndLookForEmitter(
                    ray, true, sampler, medium, its, dRec);
                if (!value.isZero(Epsilon)) {
                    Float    mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                    Spectrum nee_bsdf   = throughput * value * mis_weight;
                    nee_value += nee_bsdf;
                    ret += nee_bsdf;
                }
                handleNee(nee_value);

                handleSurface(scene, sampler, its, max_bounces - depth - 1, throughput);

                incEmission = false;
                depth++;
            }
        }
    }
    if (null_interations == max_null_interactions) {
        if (verbose)
            PSDR_INFO("Max null interactions {} reached. Dead loop?", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
    }
    report({.depth = depth});
    return ret;
}

/**
 * Computes the next-event estimate at a medium point, hands it to the
 * handleNee(Spectrum) hook, and returns the same value to the caller.
 *
 * @param medium medium containing `p`; must not be null
 * @param p      scattering location
 * @param wi     first argument to the phase function
 */
Spectrum VolpathBase::handleNee(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p, const Vector &wi) {
    assert(medium);
    Spectrum estimate = nee(scene, sampler, medium, p, wi);
    handleNee(estimate);
    return estimate;
}

/**
 * Emission contribution of an emitter hit.
 *
 * @param its         intersection; must lie on an emitter
 * @param wo          direction passed to its.Le()
 * @param incEmission whether emission is counted at this point of the path
 * @return its.Le(wo) when incEmission is true, zero otherwise
 */
Spectrum VolpathBase::handleEmission(const Intersection &its, const Vector &wo, bool incEmission) {
    assert(its.isEmitter());
    if (!incEmission) {
        return Spectrum(0.);
    }
    return its.Le(wo);
}