#include <render/scene.h>
#include "algorithm1.h"
#include "algorithm1_vol.h"

NAMESPACE_BEGIN(algorithm1_vol)

// Enzyme trampoline for evalFirstVertexContrib: returns the result through the
// `value` out-parameter instead of a return value. d_evalFirstVertexContrib
// differentiates this function with `value` marked enzyme_dup, using the
// caller-supplied adjoint as its shadow.
// Do not change the parameter order: the __enzyme_autodiff call site assigns
// activity annotations (enzyme_dup/enzyme_const) to parameters by position.
void __evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalFirstVertexContrib(scene, sampler, pixel_idx, curV, nextV);
}

// Enzyme trampoline for evalVertexContrib(preV, curV, nextV): writes the result
// into the `value` out-parameter so d_evalVertexContrib can mark it enzyme_dup.
// NOTE(review): curV is taken by non-const reference, presumably because
// evalVertexContrib requires a mutable Intersection. The reverse-mode caller
// passes a pointer to a const object here, so confirm evalVertexContrib does not
// actually modify curV.
// Do not change the parameter order: the __enzyme_autodiff call site assigns
// activity annotations to parameters by position.
void __evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalVertexContrib(scene, sampler, preV, curV, nextV);
}

// Enzyme trampoline for evalLastVertexContrib(curV, neeV): writes the result
// into the `value` out-parameter so d_evalLastVertexContrib can mark it
// enzyme_dup and seed its adjoint.
// Do not change the parameter order: the __enzyme_autodiff call site assigns
// activity annotations to parameters by position.
void __evalLastVertexContrib(const Scene &scene,
                             const Intersection &curV, const Intersection &neeV, Spectrum &value)
{
    value = evalLastVertexContrib(scene, curV, neeV);
}

// Reverse-mode derivative of evalFirstVertexContrib, taken with Enzyme through
// the __evalFirstVertexContrib trampoline.
//   d_value          adjoint seed for the contribution (a local copy, so the
//                    caller's value is left untouched)
//   d_scene          shadow of scene; receives the scene gradients
//   d_curV, d_nextV  shadows of the two vertices; receive their gradients
//   sampler, pixel_idx are passed as enzyme_const (not differentiated).
// Compiles to a no-op unless both ENZYME and VOLPATH are defined.
void d_evalFirstVertexContrib(const Scene &scene, Scene &d_scene,
                              RndSampler *sampler, const Array2i &pixel_idx,
                              const Intersection &curV, Intersection &d_curV,
                              const Intersection &nextV, Intersection &d_nextV,
                              Spectrum d_value)
{
#if defined(ENZYME) && defined(VOLPATH)
    // Primal output slot of the trampoline. Only needed during the autodiff
    // call; its shadow is d_value.
    Spectrum value;
    __enzyme_autodiff(
        (void *)__evalFirstVertexContrib,
        enzyme_dup,   &scene,     &d_scene,
        enzyme_const, sampler,
        enzyme_const, &pixel_idx,
        enzyme_dup,   &curV,      &d_curV,
        enzyme_dup,   &nextV,     &d_nextV,
        enzyme_dup,   &value,     &d_value);
#endif
}

// Reverse-mode derivative of evalVertexContrib(preV, curV, nextV), taken with
// Enzyme through the __evalVertexContrib trampoline.
//   d_value                   adjoint seed for the contribution (local copy)
//   d_scene                   receives the scene gradients
//   d_preV, d_curV, d_nextV   receive the gradients of the three vertices
//   sampler is passed as enzyme_const (not differentiated).
// Compiles to a no-op unless both ENZYME and VOLPATH are defined.
void d_evalVertexContrib(const Scene &scene, Scene &d_scene,
                         RndSampler *sampler,
                         const Intersection &preV, Intersection &d_preV,
                         const Intersection &curV, Intersection &d_curV,
                         const Intersection &nextV, Intersection &d_nextV,
                         Spectrum d_value)
{
#if defined(ENZYME) && defined(VOLPATH)
    // Primal output slot of the trampoline; its shadow is d_value.
    Spectrum value;
    __enzyme_autodiff(
        (void *)__evalVertexContrib,
        enzyme_dup,   &scene, &d_scene,
        enzyme_const, sampler,
        enzyme_dup,   &preV,  &d_preV,
        enzyme_dup,   &curV,  &d_curV,
        enzyme_dup,   &nextV, &d_nextV,
        enzyme_dup,   &value, &d_value);
#endif
}
// Reverse-mode derivative of evalLastVertexContrib(curV, neeV), taken with
// Enzyme through the __evalLastVertexContrib trampoline.
//   d_value          adjoint seed for the contribution (local copy)
//   d_scene          receives the scene gradients
//   d_curV, d_neeV   receive the gradients of the two vertices
// Compiles to a no-op unless both ENZYME and VOLPATH are defined.
void d_evalLastVertexContrib(const Scene &scene, Scene &d_scene,
                             const Intersection &curV, Intersection &d_curV,
                             const Intersection &neeV, Intersection &d_neeV,
                             Spectrum d_value)
{
#if defined(ENZYME) && defined(VOLPATH)
    // Primal output slot of the trampoline; its shadow is d_value.
    Spectrum value;
    __enzyme_autodiff(
        (void *)__evalLastVertexContrib,
        enzyme_dup, &scene, &d_scene,
        enzyme_dup, &curV,  &d_curV,
        enzyme_dup, &neeV,  &d_neeV,
        enzyme_dup, &value, &d_value);
#endif
}

void d_evalVertex(const Scene &scene, Scene &d_scene,
                  const LightPath &path, LightPath &d_path,
                  RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    auto &d_vertices = d_path.vertices;
    const auto &vtxIds = path.vs;

    const Intersection &camV = vertices[0];
    Intersection &d_camV = d_vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    const Intersection &firstV = vertices[vtxIds[1]];
    Intersection &d_firstV = d_vertices[vtxIds[1]];
    d_evalFirstVertexContrib(
        scene, d_scene, sampler, pixel_idx,
        camV, d_camV, firstV, d_firstV,
        d_camV.value);

    if (firstV.type == EVEmitter)
    {
        d_evalLastVertexContrib(
            scene, d_scene,
            camV, d_camV, firstV, d_firstV,
            d_camV.nee_bsdf);
    }

    for (int i = 1; i < vtxIds.size(); i++)
    {
        const Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &d_preV = d_vertices[vtxIds[i - 1]];
        const Intersection &curV = vertices[vtxIds[i]];
        Intersection &d_curV = d_vertices[vtxIds[i]];
        if (i < vtxIds.size() - 1)
        {
            const Intersection &nextV = vertices[vtxIds[i + 1]];
            Intersection &d_nextV = d_vertices[vtxIds[i + 1]];
            d_evalVertexContrib(
                scene, d_scene, sampler,
                preV, d_preV, curV, d_curV, nextV, d_nextV,
                d_curV.value);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            const Intersection &neeV = vertices[nee_id];
            Intersection &d_neeV = d_vertices[nee_id];
            if (nee_id == curV.l_nee_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.nee_bsdf);
            if (nee_id == curV.l_bsdf_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.bsdf_bsdf);
            // compute Le
            d_evalLastVertexContrib(
                scene, d_scene,
                curV, d_curV, neeV, d_neeV,
                d_neeV.value);
        }
    }
}

// Forward-mode (tangent) derivative of evalVertex, computed with Enzyme.
// scene/path are the primal arguments and d_scene/d_path their shadows;
// sampler is passed as enzyme_const (not differentiated). Presumably d_scene
// carries the input tangent and the result propagates into d_path. TODO:
// confirm against evalVertex's signature.
// NOTE(review): unlike the reverse-mode helpers in this file, this call is not
// wrapped in `#if defined(ENZYME) && defined(VOLPATH)`. Confirm it is meant to
// be compiled in every build configuration.
void evalVertexFwd(const Scene &scene, Scene &d_scene,
                   const LightPath &path, LightPath &d_path,
                   RndSampler *sampler)
{
    __enzyme_fwddiff((void *)evalVertex,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &path, &d_path,
                     enzyme_const, sampler);
}

NAMESPACE_END(algorithm1_vol)
