#include "volpath2.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <cmath>
#include <iomanip>
#include <algorithm1.h>
#include <core/statistics.h>
#include <render/spiral.h>
#include <fmt/core.h>

namespace volpath2_meta
{
    // Volumetric path-tracing estimate of the radiance arriving along `_ray`.
    //
    // Each loop iteration either samples a scattering event inside the current
    // medium (a phase-function vertex) or handles the next surface hit. At every
    // scattering vertex, emitter sampling (NEE) and phase/BSDF sampling are
    // combined with miWeight(). Surfaces whose BSDF is null are crossed without
    // counting a bounce (only `null_interations` is incremented) and may switch
    // the current medium.
    //
    // scene : scene to trace.
    // _ray  : primary ray; copied into the mutable local `ray`.
    // rRec  : supplies the sampler, the pixel index, the starting medium id
    //         (-1 means "not inside any medium") and the max bounce count.
    // path  : optional recorder. When non-null it is cleared for rRec.pixel_idx
    //         and receives the camera, every scattering vertex together with
    //         the pdf accumulated for reaching it, and every emitter
    //         contribution (append_nee / append_bsdf) together with its pdf
    //         divided by the MIS weight.
    // Returns the accumulated radiance estimate.
    Spectrum __Li(const Scene &scene, const Ray &_ray, const RadianceQueryRecord &rRec, LightPath *path)
    {
        if (path)
        {
            path->clear(rRec.pixel_idx);
            path->append(scene.camera); // NOTE: the camera is the first recorded vertex
        }

        Spectrum ret = Spectrum::Zero(),
                 throughput = Spectrum::Ones();

        // FIXME: assume the camera is outside any shapes
        // (the starting medium is nevertheless taken from rRec.med_id)
        int med_id = rRec.med_id;
        const Medium *medium = (med_id != -1 ? scene.getMedium(med_id) : nullptr);

        // Pdf of the current path segment, accumulated since the last scattering
        // vertex. pdfFailure is used when the segment ends at a surface vertex;
        // pdfSuccess is used when it ends at a medium scattering vertex. Both are
        // multiplied by mRec.pdfFailure whenever a surface is reached through a medium.
        Float pdfFailure = 1.; // keep track of the pdf of hitting a surface
        Float pdfSuccess = 1.; // keep track of the pdf of hitting a medium

        // Previous scattering vertex; used for the geometric term of recorded vertices.
        Vector preX = scene.camera.cpos;
        // True until the first scattering bounce. Gates direct emission on surface
        // hits and selects how vertex pdfs are recorded.
        bool incEmission = true;

        Ray ray(_ray);
        Intersection its;
        // NOTE(review): the bool argument looks like an "origin lies on a surface"
        // flag (false here and after medium events, true for rays leaving a
        // surface) - confirm against Scene::rayIntersect.
        scene.rayIntersect(ray, false, its);

        MediumSamplingRecord mRec;
        RndSampler *sampler = rRec.sampler;
        const int max_bounces = rRec.max_bounces;
        int depth = 0, null_interations = 0;
        // Stop after max_bounces scattering events, or after max_null_interactions
        // null-surface crossings in total (guards against endless loops).
        while (depth <= max_bounces && null_interations < max_null_interactions)
        {
            // Free-flight sampling in the current medium (if any). True means a
            // scattering event was sampled before the surface at its.t. When a
            // medium is present, mRec is read in both branches below.
            bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
            if (inside_med)
            {
                // sampled a medium interaction

                if (depth >= max_bounces)
                    break;

                // Record the medium vertex. Before the first bounce, the pdf is stored
                // without a geometric factor; afterwards, geometric(preX, mRec.p) is
                // folded in.
                if (path)
                {
                    if (incEmission)
                        path->append({mRec, med_id, pdfSuccess * mRec.pdfSuccess}); // NOTE
                    else
                        path->append({mRec, med_id, pdfSuccess * mRec.pdfSuccess * geometric(preX, mRec.p)}); // NOTE
                }

                const PhaseFunction *phase = scene.phase_list[medium->phase_id];
                // Scattering-event weight: sigma_s * transmittance / pdf of sampling this distance.
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(mRec.p);
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, mRec.medium);
                if (!value.isZero(Epsilon))
                {
                    Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                    if (phaseVal > Epsilon)
                    {
                        Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
                        // NOTE(review): dRec.pdf / dRec.G presumably converts the emitter
                        // pdf into the same measure as phasePdf - confirm.
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                        ret += throughput * value * phaseVal * mis_weight;

                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== phase sampling =============================
                Vector wo;
                Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
                Float phasePdf = phase->pdf(-ray.dir, wo);
                if (phaseVal < Epsilon)
                    break;

                throughput *= phaseVal;
                // A new segment starts at this vertex: restart both pdf accumulators.
                pdfFailure = phasePdf;
                pdfSuccess = phasePdf;
                ray = Ray(mRec.p, wo);

                // Trace the phase-sampled ray; a non-zero value means it reached an
                // emitter, which is MIS-weighted against emitter sampling above.
                value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, mRec.medium, its, dRec);
                if (!value.isZero(Epsilon))
                {
                    Float mis_weight = miWeight(phasePdf, dRec.pdf / dRec.G);
                    ret += throughput * value * mis_weight;

                    if (path)
                        path->append_bsdf({dRec, phasePdf * dRec.G / mis_weight}); // NOTE
                }

                // update loop variables
                incEmission = false;
                preX = mRec.p;
                depth++;
            }
            else
            {
                // sampled a surface interaction

                // The surface was reached through a medium without scattering:
                // account for the free-flight "failure" pdf and transmittance.
                if (medium)
                {
                    pdfFailure *= mRec.pdfFailure;
                    pdfSuccess *= mRec.pdfFailure;
                    throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                if (!its.isValid())
                    break;

                // Direct emission is only added before the first bounce; later emitter
                // hits are accounted for (with MIS) by rayIntersectAndLookForEmitter.
                if (its.isEmitter() && incEmission)
                    ret += throughput * its.Le(-ray.dir);

                if (depth >= max_bounces)
                    break;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(its);
                // Null-BSDF surfaces are not scattering vertices: they are neither
                // recorded nor used for emitter sampling.
                if (!its.getBSDF()->isNull())
                {
                    if (path)
                    {
                        if (incEmission)
                            path->append({its, pdfFailure}); // NOTE
                        else
                            path->append({its, pdfFailure * geometric(preX, its.p, its.geoFrame.n)}); // NOTE
                    }

                    Spectrum value = scene.sampleAttenuatedEmitterDirect(
                        dRec, its, sampler->next2D(), sampler, medium);

                    if (!value.isZero(Epsilon))
                    {
                        Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                        Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                        ret += throughput * value * bsdfVal * mis_weight;
                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== BSDF sampling =============================
                Vector wo;
                Float bsdfPdf, bsdfEta;
                // bsdfWeight multiplies throughput directly below, so it presumably
                // already includes the division by bsdfPdf - confirm.
                Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
                if (bsdfWeight.isZero(Epsilon))
                    break;

                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                throughput *= bsdfWeight;
                // Crossing a medium boundary: switch to the medium on the side of wo.
                if (its.isMediumTransition())
                {
                    med_id = its.getTargetMediumId(wo);
                    medium = its.getTargetMedium(wo);
                }
                if (its.getBSDF()->isNull())
                {
                    // Pass straight through a null surface: no bounce is consumed, and
                    // the pdf accumulators, preX and incEmission are left unchanged.
                    scene.rayIntersect(ray, true, its);
                    null_interations++;
                }
                else
                {
                    // Real scattering vertex: start a new segment and look for an
                    // emitter along the BSDF-sampled direction (MIS with NEE above).
                    pdfFailure = bsdfPdf;
                    pdfSuccess = bsdfPdf;
                    Spectrum value = scene.rayIntersectAndLookForEmitter(
                        ray, true, sampler, medium, its, dRec);
                    if (!value.isZero(Epsilon))
                    {
                        Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                        ret += throughput * value * mis_weight;
                        if (path)
                            path->append_bsdf({dRec, bsdfPdf * dRec.G / mis_weight}); // NOTE
                    }
                    incEmission = false;
                    preX = ray.org;
                    depth++;
                }
            }
        }
        if (null_interations == max_null_interactions)
        {
            if (verbose)
                fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
            // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        }
        return ret;
    }

    // Normal-velocity term of an intersection point.
    //
    // Writes res = (x - detach(x)) . detach(n), where x and n are the point and
    // normal returned by scene.getPoint(its); the third binding J is unused.
    // Assuming detach() returns a copy that is excluded from differentiation,
    // res evaluates to 0 while its derivative with respect to the scene is the
    // normal component of dx - TODO confirm against detach()'s definition.
    void velocity(const Scene &scene, const Intersection &its, Float &res)
    {
        auto [x, n, J] = scene.getPoint(its);
        res = (x - detach(x)).dot(detach(n));
    }

    // Back-propagates the adjoint d_u of velocity()'s output into the scene
    // derivative via Enzyme's __enzyme_autodiff. The scene is passed as
    // enzyme_dup (primal sceneAD.val, shadow sceneAD.getDer()), the intersection
    // as enzyme_const, and the output `res` as enzyme_dup with shadow d_u. `u`
    // only provides storage for the primal output.
    // When ENZYME is not defined, this function does nothing.
    void d_velocity(SceneAD &sceneAD, const Intersection &its, Float &d_u)
    {
        [[maybe_unused]] Float u;
#if defined(ENZYME)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &sceneAD.val, &sceneAD.getDer(),
                          enzyme_const, &its,
                          enzyme_dup, &u, &d_u);
#endif
    }
} // namespace volpath2_meta

// Primal radiance of the interior term: a plain volumetric path-traced
// estimate (volpath2_meta::__Li) without recording the light path.
Spectrum VolpathInterior::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    LightPath *const noRecording = nullptr; // primal pass only; nothing is recorded
    Spectrum radiance = volpath2_meta::__Li(scene, ray, rRec, noRecording);
    return radiance;
}

// Reverse-mode evaluation of the interior term.
//
// Re-traces the path with volpath2_meta::__Li while recording its vertices
// into a LightPath. If the traced radiance is non-zero (beyond Epsilon), the
// recorded path is wrapped in a LightPathAD and passed to algorithm1_vol::d_eval
// together with the derivative scene sceneAD.getDer() and the adjoint d_res.
//
// NOTE(review): optnone presumably keeps this function unoptimized for the
// AD toolchain - confirm before removing it.
__attribute__((optnone)) void VolpathInterior::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    LightPath recorded;
    const bool carriesRadiance =
        !volpath2_meta::__Li(sceneAD.val, ray, rRec, &recorded).isZero(Epsilon);
    if (carriesRadiance)
    {
        LightPathAD recordedAD(recorded);
        algorithm1_vol::d_eval(sceneAD.val, sceneAD.getDer(), recordedAD, d_res, rRec.sampler);
    }
}

// Primal radiance of the boundary term. This term has no primal contribution
// and always returns zero; only VolpathBoundary::LiAD produces anything.
Spectrum VolpathBoundary::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    (void)scene;
    (void)ray;
    (void)rRec;
    Spectrum zero = Spectrum::Zero();
    return zero;
}

// Boundary term of the derivative (eq. 31) for the volumetric path tracer.
//
// For every intersection returned by scene.rayIntersectAll(ray, true) (the
// "null intersections" along `ray`), estimates the light in-scattered at that
// point from one directly sampled emitter point. It then back-propagates the
// estimate, weighted by d_res, through the point's normal velocity using
// volpath2_meta::d_velocity.
//
// NOTE(review): the scattering medium is hard-coded to scene.getMedium(0), and
// the camera-side transmittance is evaluated from scene.camera.cpos with no
// medium. Both look like placeholders (see FIXMEs) - confirm before reuse.
void VolpathBoundary::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    const Scene &scene = sceneAD.val;
    // find all null intersections
    const std::vector<Intersection> its_list = scene.rayIntersectAll(ray, true);
    if (its_list.empty())
        return;

    // Loop-invariant lookups, hoisted out of the per-intersection loop.
    // FIXME: hard-coded medium (index 0).
    const Medium *medium1 = scene.getMedium(0);
    auto phase = scene.getPhase(medium1->phase_id);

    // estimate the boundary integral
    for (const Intersection &its : its_list)
    {
        // Sample an emitter point as seen from its.p. Only the fields written
        // into dRec are used; the return value of this call is not.
        DirectSamplingRecord dRec(its.p);
        scene.sampleAttenuatedEmitterDirect(
            dRec, rRec.sampler->next2D(), rRec.sampler, nullptr);

        // Emitted radiance from dRec.p, attenuated through the medium that
        // its.getTargetMedium() selects for direction dRec.dir.
        const Medium *medium = its.getTargetMedium(dRec.dir);
        Spectrum value = dRec.emittance *
                         scene.evalTransmittance(its.p, true, dRec.p, true, medium, rRec.sampler) /
                         dRec.pdf;

        Float phaseVal = phase->eval(-ray.dir, dRec.dir);

        // eq. 31
        // FIXME: transmittance is evaluated from the camera position with no medium.
        Float transmittance = scene.evalTransmittance(scene.camera.cpos, true, its.p, true, nullptr, rRec.sampler);
        Float cosTheta = std::abs(ray.dir.dot(its.geoFrame.n));
        Float d_u = (d_res * value * dRec.G * medium1->sigS(its.p) * phaseVal * transmittance / cosTheta).sum();

        // A failed emitter sample (presumably dRec.pdf == 0) or a ray grazing the
        // surface (cosTheta == 0) makes d_u inf/NaN. Skip such samples instead of
        // poisoning the accumulated gradient. All sampler draws for this sample
        // have already happened, so the random sequence is unchanged.
        if (!std::isfinite(d_u))
            continue;
        volpath2_meta::d_velocity(sceneAD, its, d_u);
    }
}

// Primal rendering. This is delegated entirely to the boundary integrator.
// NOTE(review): the interior term (VolpathInterior) is not included here -
// confirm this is intentional.
ArrayXd Volpath2::renderC(const Scene &scene, const RenderOptions &options) const
{
    auto &integrator = volpathBoundary;
    ArrayXd image = integrator.renderC(scene, options);
    return image;
}

// Derivative rendering. This is delegated entirely to the boundary integrator,
// which receives the scene/derivative pair and the image adjoint d_image.
// NOTE(review): the interior term (VolpathInterior) is not included here -
// confirm this is intentional.
ArrayXd Volpath2::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const
{
    auto &integrator = volpathBoundary;
    ArrayXd result = integrator.renderD(sceneAD, options, d_image);
    return result;
}