#include "volpath.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include <algorithm1.h>
#include <core/statistics.h>

// Statistics counter (EAverage) for the average number of scattering events per
// traced path. It is updated at the end of volpath_meta::__Li, only for paths
// that performed at least one bounce (depth != 0).
static StatsCounter avgPathLength("Volumetric path tracer", "Average path length", EAverage);

// Core estimator of the volumetric path tracer, shared by Volpath::Li, LiAD,
// LiFwd and Lins.
namespace volpath_meta
{
    /**
     * Traces one volumetric path starting at `_ray` and returns its radiance
     * estimate. At every medium vertex and every non-null-BSDF surface vertex it
     * combines emitter sampling (NEE) with phase/BSDF sampling using MIS.
     *
     * @param scene  scene to trace against
     * @param _ray   starting ray; copied into the local `ray`
     * @param rRec   query parameters read here:
     *               - sampler
     *               - max_bounces
     *               - med_id: initial medium, -1 for none
     *               - incEmission: whether emission seen directly at a surface
     *                 hit is added before the first scattering event
     *               - pixel_idx: only used when recording a path
     * @param path   optional output. When non-null it is cleared and receives:
     *               - the camera
     *               - each medium / non-null-BSDF surface vertex
     *               - each NEE and BSDF/phase-sampled emitter connection,
     *                 together with the pdf-like value passed alongside it
     * @return       the accumulated radiance estimate
     *
     * Side effects:
     * - draws random numbers from rRec.sampler
     * - updates avgPathLength when at least one bounce occurred
     * - may log when the null-interaction limit is hit and `verbose` is set
     */
    Spectrum __Li(const Scene &scene, const Ray &_ray, const RadianceQueryRecord &rRec, LightPath *path)
    {
        if (path)
        {
            path->clear(rRec.pixel_idx);
            path->append(scene.camera); // NOTE
        }

        Spectrum ret = Spectrum::Zero(),
                 throughput = Spectrum::Ones();

        // FIXME: assume the camera is outside any shapes
        // Current medium; med_id == -1 (medium == nullptr) means no medium.
        int med_id = rRec.med_id;
        const Medium *medium = (med_id != -1 ? scene.getMedium(med_id) : nullptr);

        Float pdfFailure = 1.; // keep track of the pdf of hitting a surface
        Float pdfSuccess = 1.; // keep track of the pdf of hitting a medium

        // Previous path vertex; used for the geometric term of recorded vertices.
        Vector preX = scene.camera.cpos;
        // While set, emission at a surface hit is added directly. It is cleared
        // after the first scattering event; from then on emitters are only
        // reached through the MIS-weighted NEE / direction-sampling estimates.
        bool incEmission = rRec.incEmission;

        Ray ray(_ray);
        Intersection its;
        scene.rayIntersect(ray, false, its);

        MediumSamplingRecord mRec;
        RndSampler *sampler = rRec.sampler;
        const int max_bounces = rRec.max_bounces;
        // depth counts scattering events. null_interations counts passes through
        // null-BSDF surfaces, which do not increase depth.
        // NOTE(review): max_null_interactions is declared outside this file
        // (presumably volpath.h) -- confirm.
        int depth = 0, null_interations = 0;
        while (depth <= max_bounces && null_interations < max_null_interactions)
        {
            // Inside a medium, try to sample an interaction before the surface
            // at distance its.t; the sampling result is written into mRec.
            bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
            if (inside_med)
            {
                // sampled a medium interaction

                if (depth >= max_bounces)
                    break;

                // Record the medium vertex. While incEmission is still set, the
                // stored pdf omits the geometric term; otherwise it includes
                // geometric(preX, mRec.p).
                if (path)
                {
                    if (incEmission)
                        path->append({mRec, med_id, pdfSuccess * mRec.sigmaT}); // NOTE
                    else
                        path->append({mRec, med_id, pdfSuccess * geometric(preX, mRec.p) * mRec.sigmaT}); // NOTE
                }

                const PhaseFunction *phase = scene.phase_list[medium->phase_id];
                // Weight by sigmaS and the segment transmittance, divided by the
                // distance-sampling pdf.
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(mRec.p);
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, mRec.medium);
                if (!value.isZero(Epsilon))
                {
                    Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                    if (phaseVal > Epsilon)
                    {
                        Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
                        // MIS against phase sampling. dRec.pdf / dRec.G presumably
                        // converts the emitter pdf to the solid-angle measure of
                        // phasePdf -- confirm.
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                        ret += throughput * value * phaseVal * mis_weight;
                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== phase sampling =============================
                Vector wo;
                Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
                Float phasePdf = phase->pdf(-ray.dir, wo);
                if (phaseVal < Epsilon)
                    break;

                throughput *= phaseVal;
                // Both pdfs restart from this vertex's direction-sampling pdf.
                pdfFailure = phasePdf;
                pdfSuccess = phasePdf;
                ray = Ray(mRec.p, wo);

                // Trace the sampled direction. `its` is refreshed here, presumably
                // for the next iteration's distance sampling. A non-zero value
                // presumably means an emitter was reached.
                value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, mRec.medium, its, dRec);
                if (!value.isZero(Epsilon))
                {
                    Float pdf_emitter = scene.pdfEmitterDirect(dRec);
                    Float mis_weight = miWeight(phasePdf, pdf_emitter / dRec.G);
                    ret += throughput * value * mis_weight;
                    if (path)
                        path->append_bsdf({dRec, phasePdf * dRec.G / mis_weight}); // NOTE
                }

                // update loop variables
                incEmission = false;
                preX = mRec.p;
                depth++;
            }
            else
            {
                // sampled a surface interaction

                if (medium)
                {
                    // Account for the medium segment that reached the surface:
                    // transmittance over the pdf of reaching the surface.
                    pdfFailure *= mRec.pdfFailure; // REVIEW
                    // pdfSuccess *= mRec.pdfFailure; // REVIEW
                    throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                // No surface hit (presumably the ray escaped the scene).
                if (!its.isValid())
                    break;

                // Direct emission is only added while incEmission is set.
                if (its.isEmitter() && incEmission)
                    ret += throughput * its.Le(-ray.dir);

                if (depth >= max_bounces)
                    break;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(its);
                // Surfaces with a null BSDF create no vertex and no NEE; they are
                // passed through further below.
                if (!its.getBSDF()->isNull())
                {
                    if (path)
                    {
                        if (incEmission)
                            path->append({its, pdfFailure}); // REVIEW
                        else
                            path->append({its, pdfFailure * geometric(preX, its.p, its.geoFrame.n)}); // REVIEW
                    }

                    // NOTE(review): terminates the path when the hit's emitter is an
                    // EnvironmentMap. This runs after the vertex has already been
                    // recorded -- confirm intent.
                    if(dynamic_cast<const EnvironmentMap *>(its.ptr_emitter))
                        break;

                    Spectrum value = scene.sampleAttenuatedEmitterDirect(
                        dRec, its, sampler->next2D(), sampler, medium);

                    if (!value.isZero(Epsilon))
                    {
                        Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                        Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                        ret += throughput * value * bsdfVal * mis_weight;
                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== BSDF sampling =============================
                Vector wo;
                Float bsdfPdf, bsdfEta;
                Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
                if (bsdfWeight.isZero(Epsilon))
                    break;

                // wo is sampled in the local frame; convert to world space.
                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                throughput *= bsdfWeight;
                // Crossing a medium boundary: switch to the medium on the side of wo.
                if (its.isMediumTransition())
                {
                    med_id = its.getTargetMediumId(wo);
                    medium = its.getTargetMedium(wo);
                }
                if (its.getBSDF()->isNull())
                {
                    // Pass through a null surface. The same segment continues
                    // without counting a bounce, resetting the pdfs or moving preX.
                    scene.rayIntersect(ray, true, its);
                    null_interations++;
                }
                else
                {
                    pdfFailure = bsdfPdf;
                    pdfSuccess = bsdfPdf;
                    Spectrum value = scene.rayIntersectAndLookForEmitter(
                        ray, true, sampler, medium, its, dRec);
                    if (!value.isZero(Epsilon))
                    {
                        // NOTE(review): this branch uses dRec.pdf directly, while the
                        // medium branch calls scene.pdfEmitterDirect(dRec) -- confirm
                        // both are equivalent.
                        Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                        ret += throughput * value * mis_weight;
                        if (path)
                            path->append_bsdf({dRec, bsdfPdf * dRec.G / mis_weight}); // NOTE
                    }
                    incEmission = false;
                    // ray.org is the surface point this bounce scattered from.
                    preX = ray.org;
                    depth++;
                }
            }
        }
        // The loop stopped because of too many null-surface crossings.
        if (null_interations == max_null_interactions)
        {
            if (verbose)
                PSDR_INFO("Max null interactions {} reached. Dead loop?", max_null_interactions);
            // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        }
        if (depth != 0) {
            avgPathLength.incrementBase();
            avgPathLength += depth;
        }
        return ret;
    }
} // namespace volpath_meta

/**
 * Radiance estimate along `ray`.
 *
 * With the `#if 1` branch enabled (the current setting), the path is traced by
 * volpath_meta::__Li only to record its vertices into `path`. The returned
 * value is that recorded path re-evaluated by algorithm1_vol::eval.
 *
 * The `#else` branch instead returns __Li's own estimate without recording a
 * path.
 */
Spectrum Volpath::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
#if 1
    LightPath path;
    // Called for its side effects only: it fills `path` and advances
    // rRec.sampler. Its radiance estimate is not needed here.
    volpath_meta::__Li(scene, ray, rRec, &path);
    return algorithm1_vol::eval(scene, path, rRec.sampler);
#else
    return volpath_meta::__Li(scene, ray, rRec, nullptr);
#endif
}

/**
 * Reverse-mode pass for one ray.
 *
 * Traces and records the path in the primal scene (sceneAD.val). If the path
 * carries non-negligible radiance, the recorded path is handed to
 * algorithm1_vol::d_eval together with the adjoint `d_res`, which writes into
 * sceneAD.getDer().
 */
void Volpath::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    LightPath recorded;
    Spectrum estimate = volpath_meta::__Li(sceneAD.val, ray, rRec, &recorded);

    // Nothing to back-propagate through a path with (numerically) zero radiance.
    if (estimate.isZero(Epsilon))
        return;

    LightPathAD recordedAD(recorded);
    algorithm1_vol::d_eval(sceneAD.val, sceneAD.getDer(), recordedAD, d_res, rRec.sampler);
}

/** In-scattered radiance at a medium point.
 *
 *  Lins(x, wi) = \int f(x, wi, wo) L(x, wo) dwo
 *  L(x, wi) = Lins(x, -wi)
 *
 *  One phase-function bounce is performed at `p`:
 *  - emitter sampling (NEE) and phase sampling are combined with MIS;
 *  - the path then continues along the phase-sampled direction via
 *    volpath_meta::__Li.
 *
 *  @param scene   scene to trace against
 *  @param p       medium interaction point
 *  @param wi      first direction argument passed to the phase function
 *  @param medium  medium containing `p`; must be non-null
 *  @param _rRec   query record; copied, so the caller's record is not modified
 *  @return        the in-scattered radiance estimate. Zero when
 *                 `_rRec.max_bounces < 0` or `medium` is null.
 *
 *  FIXME: currently only works for medium interaction
 */
Spectrum Volpath::Lins(const Scene &scene, const Vector &p, const Vector &wi,
                       const Medium *medium, const RadianceQueryRecord &_rRec) const
{
    RadianceQueryRecord rRec(_rRec);
    if (_rRec.max_bounces < 0)
        return Spectrum(0.0f);
    // The assert documents the precondition in debug builds. The explicit check
    // keeps NDEBUG builds from dereferencing a null medium below.
    assert(medium != nullptr);
    if (medium == nullptr)
        return Spectrum(0.0f);
    RndSampler *sampler = rRec.sampler;

    Spectrum ret = Spectrum::Zero();
    Spectrum throughput = Spectrum::Ones();
    // Incoming ray. Only ray.dir is read until `ray` is reassigned after phase
    // sampling, so its origin does not matter here.
    Ray ray(p, -wi);
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);

    // ====================== emitter sampling =========================
    // No side effects except on `ret` (and the sampler state).
    DirectSamplingRecord dRec(p);
    Spectrum value = scene.sampleAttenuatedEmitterDirect(
        dRec, sampler->next2D(), sampler, medium);
    if (!value.isZero(Epsilon))
    {
        Float phaseVal = phase->eval(-ray.dir, dRec.dir);
        if (phaseVal > Epsilon)
        {
            // MIS against phase sampling, using the same dRec.pdf / dRec.G
            // convention as volpath_meta::__Li.
            Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
            Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
            ret += throughput * value * phaseVal * mis_weight;
        }
    }

    // ====================== phase sampling =============================
    Vector wo;
    Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
    Float phasePdf = phase->pdf(-ray.dir, wo);
    if (phaseVal < Epsilon)
        return ret;

    throughput *= phaseVal;
    ray = Ray(p, wo);
    Intersection its; // output argument of the call below; not otherwise used
    value = scene.rayIntersectAndLookForEmitter(
        ray, false, sampler, medium, its, dRec);
    if (!value.isZero(Epsilon))
    {
        Float pdf_emitter = scene.pdfEmitterDirect(dRec);
        Float mis_weight = miWeight(phasePdf, pdf_emitter / dRec.G);
        ret += throughput * value * mis_weight;
    }

    // Continue the path along wo. Any emitter hit along wo was already added
    // (MIS-weighted) above, so direct emission is disabled inside __Li.
    rRec.incEmission = false;
    // FIXME: hardcoded medium id; it should be derived from `medium`.
    rRec.med_id = 0;
    ret += throughput * volpath_meta::__Li(scene, ray, rRec, nullptr);
    return ret;
}

/**
 * Forward-mode derivative of the radiance along `ray`.
 *
 * Traces and records the path with volpath_meta::__Li. If that estimate is
 * non-negligible, the recorded path is re-evaluated with
 * algorithm1_vol::evalFwd and only the derivative part is returned.
 * Otherwise the result is zero.
 */
Spectrum Volpath::LiFwd(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec) const
{
    LightPath path;
    Spectrum value = volpath_meta::__Li(sceneAD.val, ray, rRec, &path);
    if (value.isZero(Epsilon))
        return Spectrum(0.0f);

    LightPathAD pathAD(path);
    // Distinct binding name so the outer `value` is not shadowed; only the
    // derivative component is used.
    auto [fwd_value, d_value] = algorithm1_vol::evalFwd(sceneAD.val, sceneAD.getDer(), pathAD, rRec.sampler);
    (void)fwd_value;
    return d_value;
}