#include "algorithm1.h"
#include <render/scene.h>

NAMESPACE_BEGIN(algorithm1)

// Normal "velocity" of the point referenced by `its`.
// scene.getPoint(its) yields three values (presumably position, normal and
// Jacobian); only the first two are used. The result is
//     (pos - detach(pos)) . detach(normal)
// NOTE(review): assumes detach() blocks derivative propagation, so the value
// is zero while its derivative is d(pos) projected on the normal — confirm.
void velocity(const Scene &scene, const Intersection &its, Float &res)
{
    [[maybe_unused]] auto [pos, normal, jacobian] = scene.getPoint(its);
    res = (pos - detach(pos)).dot(detach(normal));
}

// Reverse-mode derivative of velocity(): seeds the adjoint of its output with
// d_u and lets Enzyme accumulate the resulting scene adjoints into the
// derivative scene returned by sceneAD.getDer(). `its` is treated as constant.
// Compiles to a no-op unless ENZYME is defined.
void d_velocity(SceneAD &sceneAD, const Intersection &its, Float d_u)
{
    auto &d_scene = sceneAD.getDer();
    // Primal output slot required by velocity()'s signature; only its shadow
    // (d_u, a by-value copy) matters here.
    [[maybe_unused]] Float u;
#if defined(ENZYME)
    __enzyme_autodiff((void *)velocity,
                      enzyme_dup, &sceneAD.val, &d_scene,
                      enzyme_const, &its,
                      enzyme_dup, &u, &d_u);
#endif
}

// Recomputes the geometric data of vertex `v` from the (possibly
// differentiated) scene, so that derivatives can flow from scene parameters
// into the vertex.
//   - EVSensor (exact match): v.p is set to the camera position.
//   - any vertex with the EVSurface bit: v.p is re-interpolated from the
//     triangle's vertices with barycentric weights (1-u-v, u, v); geoFrame and
//     shFrame are rebuilt from the face / shading normal; v.J is set to
//     area / detach(area).
//   - EVInvalid: left untouched.
// Returns true for the cases above; any other type trips assert(false) and
// returns false.
bool getPoint(const Scene &scene, Intersection &v)
{
    if (v.type == EVSensor)
    {
        v.p = scene.camera.cpos;
        return true;
    }
    if (v.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[v.shape_id];
        const Vector3i &ind = shape->indices[v.triangle_id];
        const Vector &v0 = shape->getVertex(ind[0]),
                     &v1 = shape->getVertex(ind[1]),
                     &v2 = shape->getVertex(ind[2]);
        v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
              v.barycentric.x() * v1 +
              v.barycentric.y() * v2;
        Vector geo_n = shape->getFaceNormal(v.triangle_id);
        Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
        v.geoFrame = Frame(geo_n);
        v.shFrame = Frame(sh_n);
        v.J = shape->getArea(v.triangle_id);
        // NOTE(review): presumably J == 1 in value but carries d(area)/area
        // as its derivative, assuming detach() stops gradients — confirm.
        v.J /= detach(v.J);
        return true;
    }
    if (v.type == EVInvalid)
        return true;
    assert(false);
    return false;
}

// Refreshes the geometry of every vertex in `path` via getPoint(), in order.
// Stops at the first vertex for which getPoint() reports failure and returns
// false; returns true once every vertex has been processed.
bool getPath(const Scene &scene, LightPath &path)
{
    for (auto &vertex : path.vertices)
        if (!getPoint(scene, vertex))
            return false;
    return true;
}

// Reverse-mode derivative of getPoint() for a single vertex: the adjoints
// stored in d_v are propagated back through getPoint() and accumulated into
// d_scene by Enzyme. Compiled only when both ENZYME and PATH are defined;
// otherwise a no-op.
// NOTE(review): `v` is const here but getPoint() writes to it; this relies on
// the referenced object not actually being const — confirm at call sites.
void d_getPoint(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v)
{
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)getPoint,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &v, &d_v);
#endif
}

// Reverse-mode counterpart of getPath(): back-propagates the per-vertex
// adjoints in d_path.vertices into d_scene, one vertex at a time.
// Only vertices whose type has the EVSurface or EVVolume bit are visited;
// all others (e.g. the sensor vertex) are skipped.
void d_getPath(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path)
{
    const int count = static_cast<int>(path.vertices.size());
    for (int idx = 0; idx < count; ++idx)
    {
        const auto &vtx = path.vertices[idx];
        const bool differentiable = vtx.type & EVSurface || vtx.type & EVVolume;
        if (differentiable)
            d_getPoint(scene, d_scene, vtx, d_path.vertices[idx]);
    }
}

// Camera-side factor of a path: the value of scene.camera.evalFilter() for
// pixel `pixel_idx` at the first hit nextV (presumably the pixel
// reconstruction-filter weight). A surface hit also passes nextV's geometric
// normal. If nextV has neither the EVVolume nor the EVSurface bit, the result
// is zero; if it has both, the surface evaluation wins because it runs last.
// curV must be the sensor vertex (asserted).
Spectrum evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV)
{
    assert(curV.type & EVSensor);
    const auto px = pixel_idx[0];
    const auto py = pixel_idx[1];

    Spectrum weight = Spectrum::Zero();
    if (nextV.type & EVVolume)
        weight = scene.camera.evalFilter(px, py, nextV.p);
    if (nextV.type & EVSurface)
        weight = scene.camera.evalFilter(px, py, nextV.p, nextV.geoFrame.n);
    return weight;
}
// Out-parameter adapter for Enzyme: same as evalFirstVertexContrib(), but the
// result is written to `value` so it can be marked enzyme_dup
// (see d_evalFirstVertexContrib).
void __evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalFirstVertexContrib(scene, sampler, pixel_idx, curV, nextV);
}

// Contribution of an interior vertex curV reached from preV and continuing
// towards nextV:
//     bsdf(curV, wo) * geometric(curV.p, nextV.p, nextV.n) * curV.J / curV.pdf
// Side effect: stores the local-frame direction towards preV in curV.wi
// before evaluating the BSDF (which receives curV and presumably reads wi).
// Both curV and nextV are asserted to be surface vertices; in builds without
// asserts a non-surface vertex simply skips the corresponding factor.
Spectrum evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV)
{
    const auto toPrev = (preV.p - curV.p).normalized();
    curV.wi = curV.toLocal(toPrev);
    assert(curV.type != EVInvalid);

    Spectrum contrib = Spectrum::Ones();

    assert(curV.type & EVSurface);
    if (curV.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[curV.shape_id];
        const BSDF *bsdf = scene.bsdf_list[shape->bsdf_id];
        const auto toNext = (nextV.p - curV.p).normalized();
        contrib *= bsdf->eval(curV, curV.toLocal(toNext));
    }

    assert(nextV.type & EVSurface);
    if (nextV.type & EVSurface)
        contrib *= geometric(curV.p, nextV.p, nextV.geoFrame.n);

    contrib *= curV.J / curV.pdf;
    return contrib;
}
// Out-parameter adapter for Enzyme: same as evalVertexContrib(), but the
// result is written to `value` so it can be marked enzyme_dup
// (see d_evalVertexContrib). Like evalVertexContrib(), this writes curV.wi.
void __evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalVertexContrib(scene, sampler, preV, curV, nextV);
}

// Emitted radiance leaving the emitter vertex neeV towards curV, weighted by
// neeV's Jacobian over its sampling pdf:
//     emitter.eval(neeV.n, dir(neeV -> curV)) * neeV.J / neeV.pdf
// NOTE(review): no check is made that neeV's shape actually carries a valid
// light_id — callers presumably guarantee neeV lies on an emitter; confirm.
Spectrum evalLastVertexContrib(const Scene &scene, RndSampler *sampler,
                               const Intersection &curV, const Intersection &neeV)
{
    const Shape *emitterShape = scene.shape_list[neeV.shape_id];
    const Emitter *light = scene.emitter_list[emitterShape->light_id];
    const auto towardsCur = (curV.p - neeV.p).normalized();
    return light->eval(neeV.geoFrame.n, towardsCur) * neeV.J / neeV.pdf;
}
// Out-parameter adapter for Enzyme: same as evalLastVertexContrib(), but the
// result is written to `value` so it can be marked enzyme_dup
// (see d_evalLastVertexContrib).
void __evalLastVertexContrib(const Scene &scene, RndSampler *sampler,
                             const Intersection &curV, const Intersection &neeV, Spectrum &value)
{
    value = evalLastVertexContrib(scene, sampler, curV, neeV);
}

// Reverse-mode derivative of evalFirstVertexContrib().
// d_value is the adjoint of the returned Spectrum. It is taken by value, so
// whatever Enzyme does to the output shadow during the reverse pass does not
// touch the caller's copy. Adjoints are accumulated into d_scene, d_curV and
// d_nextV; sampler and pixel_idx are constant.
// Compiled only when both ENZYME and PATH are defined; otherwise a no-op.
void d_evalFirstVertexContrib(const Scene &scene, Scene &d_scene,
                              RndSampler *sampler, const Array2i &pixel_idx,
                              const Intersection &curV, Intersection &d_curV,
                              const Intersection &nextV, Intersection &d_nextV,
                              Spectrum d_value)
{
    // Primal output slot for the adapter; its value is not used here.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)__evalFirstVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, sampler,
                      enzyme_const, &pixel_idx,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &nextV, &d_nextV,
                      enzyme_dup, &value, &d_value);
#endif
}
// Reverse-mode derivative of evalVertexContrib().
// d_value (the adjoint of the returned Spectrum) is taken by value so the
// caller's copy is not disturbed. Adjoints are accumulated into d_scene,
// d_preV, d_curV and d_nextV; sampler is constant.
// NOTE(review): curV is const here, but the differentiated function writes
// curV.wi; this relies on the referenced vertex not actually being const.
// Compiled only when both ENZYME and PATH are defined; otherwise a no-op.
void d_evalVertexContrib(const Scene &scene, Scene &d_scene,
                         RndSampler *sampler,
                         const Intersection &preV, Intersection &d_preV,
                         const Intersection &curV, Intersection &d_curV,
                         const Intersection &nextV, Intersection &d_nextV,
                         Spectrum d_value)
{
    // Primal output slot for the adapter; its value is not used here.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)__evalVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, sampler,
                      enzyme_dup, &preV, &d_preV,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &nextV, &d_nextV,
                      enzyme_dup, &value, &d_value);
#endif
}
// Reverse-mode derivative of evalLastVertexContrib().
// d_value (the adjoint of the returned Spectrum) is taken by value so the
// caller's copy is not disturbed. Adjoints are accumulated into d_scene,
// d_curV and d_neeV; sampler is constant.
// Compiled only when both ENZYME and PATH are defined; otherwise a no-op.
void d_evalLastVertexContrib(const Scene &scene, Scene &d_scene,
                             RndSampler *sampler,
                             const Intersection &curV, Intersection &d_curV,
                             const Intersection &neeV, Intersection &d_neeV,
                             Spectrum d_value)
{
    // Primal output slot for the adapter; its value is not used here.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)__evalLastVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, sampler,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &neeV, &d_neeV,
                      enzyme_dup, &value, &d_value);
#endif
}

// Evaluates and caches every per-vertex factor of `path` in the vertices
// themselves, for later assembly by evalPath():
//   - camV.value      : evalFirstVertexContrib() towards the first hit;
//   - camV.nee_bsdf   : Le of the first hit, if it is an emitter;
//   - curV.value      : evalVertexContrib(prev, cur, next) for every path
//                       vertex except the last one;
//   - curV.nee_bsdf / curV.bsdf_bsdf : evalVertexContrib() towards the
//                       emitter-sampled / BSDF-sampled vertex (if id != -1);
//   - neeV.value      : Le of that sampled vertex (evalLastVertexContrib()).
// Does nothing if the path has fewer than two vertices or the first hit is
// invalid. Note that evalVertexContrib() also overwrites curV.wi each call.
// NOTE(review): assumes the vertices referenced by l_nee_id / l_bsdf_id are
// not also entries of path.vs; otherwise their cached .value (Le) would be
// overwritten by a later iteration's evalVertexContrib() — confirm against
// path construction.
void evalVertex(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    const auto &vtxIds = path.vs;

    Intersection &camV = vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    Intersection &firstV = vertices[vtxIds[1]];
    camV.value = evalFirstVertexContrib(scene, sampler, pixel_idx, camV, firstV);
    if (firstV.type == EVEmitter)
    {
        camV.nee_bsdf = evalLastVertexContrib(scene, sampler, camV, firstV);
    }
    for (int i = 1; i < vtxIds.size(); i++)
    {
        Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &curV = vertices[vtxIds[i]];
        // The last path vertex has no successor, so no .value is cached for it.
        if (i < vtxIds.size() - 1)
        {
            Intersection &nextV = vertices[vtxIds[i + 1]];
            curV.value = evalVertexContrib(scene, sampler, preV, curV, nextV);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            Intersection &neeV = vertices[nee_id];
            Spectrum value = evalVertexContrib(scene, sampler, preV, curV, neeV);
            if (nee_id == curV.l_nee_id)
                curV.nee_bsdf = value;
            if (nee_id == curV.l_bsdf_id)
                curV.bsdf_bsdf = value;

            // compute Le
            neeV.value = evalLastVertexContrib(scene, sampler, curV, neeV);
        }
    }
}

// Reverse-mode counterpart of evalVertex(): for each factor evalVertex()
// caches, it reads that field's adjoint from d_path (d_camV.value,
// d_camV.nee_bsdf, d_curV.value, d_curV.nee_bsdf, d_curV.bsdf_bsdf,
// d_neeV.value). It then calls the matching d_eval*Contrib helper, which
// accumulates adjoints into d_scene and the d_path vertices.
// The adjoint fields are passed by value, so the helpers do not modify them.
// Must mirror evalVertex() exactly (same early return, same vertex pairing);
// expects the primal fields of `path` to have been filled by evalVertex().
// NOTE(review): `path` is const here, but the differentiated
// evalVertexContrib() writes curV.wi; this relies on the referenced path not
// actually being const.
void d_evalVertex(const Scene &scene, Scene &d_scene,
                  const LightPath &path, LightPath &d_path,
                  RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    auto &d_vertices = d_path.vertices;
    const auto &vtxIds = path.vs;

    const Intersection &camV = vertices[0];
    Intersection &d_camV = d_vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    const Intersection &firstV = vertices[vtxIds[1]];
    Intersection &d_firstV = d_vertices[vtxIds[1]];
    d_evalFirstVertexContrib(
        scene, d_scene, sampler, pixel_idx,
        camV, d_camV, firstV, d_firstV,
        d_camV.value);

    if (firstV.type == EVEmitter)
    {
        d_evalLastVertexContrib(
            scene, d_scene,
            sampler,
            camV, d_camV, firstV, d_firstV,
            d_camV.nee_bsdf);
    }

    for (int i = 1; i < vtxIds.size(); i++)
    {
        const Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &d_preV = d_vertices[vtxIds[i - 1]];
        const Intersection &curV = vertices[vtxIds[i]];
        Intersection &d_curV = d_vertices[vtxIds[i]];
        if (i < vtxIds.size() - 1)
        {
            const Intersection &nextV = vertices[vtxIds[i + 1]];
            Intersection &d_nextV = d_vertices[vtxIds[i + 1]];
            d_evalVertexContrib(
                scene, d_scene, sampler,
                preV, d_preV, curV, d_curV, nextV, d_nextV,
                d_curV.value);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            const Intersection &neeV = vertices[nee_id];
            Intersection &d_neeV = d_vertices[nee_id];
            if (nee_id == curV.l_nee_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.nee_bsdf);
            if (nee_id == curV.l_bsdf_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.bsdf_bsdf);
            // compute Le
            d_evalLastVertexContrib(
                scene, d_scene,
                sampler,
                curV, d_curV, neeV, d_neeV,
                d_neeV.value);
        }
    }
}

// Assembles the path contribution from the per-vertex factors cached by
// evalVertex(). Walking consecutive path vertices, `throughput` accumulates
// the cached .value of every vertex before the current one. Each vertex then
// adds its emitter-sampling term (nee_bsdf * Le) and BSDF-sampling term
// (bsdf_bsdf * Le) when the corresponding id is non-negative. A first hit that
// is an emitter adds camV.value * camV.nee_bsdf.
// Returns zero for paths with fewer than two vertices or an invalid first hit.
Spectrum evalPath(const Scene &scene, LightPath &path)
{
    const auto &ids = path.vs;
    if (ids.size() <= 1 || path.vertices[ids[1]].type == EVInvalid)
        return Spectrum::Zero();

    Spectrum radiance = Spectrum::Zero();
    Spectrum throughput = Spectrum::Ones();

    const Intersection &camV = path.vertices[0];
    const Intersection &firstV = path.vertices[ids[1]];
    if (firstV.type == EVEmitter)
        radiance += camV.value * camV.nee_bsdf;

    for (std::size_t k = 1; k < ids.size(); ++k)
    {
        const Intersection &preV = path.vertices[ids.at(k - 1)];
        const Intersection &curV = path.vertices[ids.at(k)];
        throughput *= preV.value;

        // emitter sampling
        if (curV.l_nee_id >= 0)
            radiance += throughput * curV.nee_bsdf * path.vertices[curV.l_nee_id].value;

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
            radiance += throughput * curV.bsdf_bsdf * path.vertices[curV.l_bsdf_id].value;
    }
    return radiance;
}

// Out-parameter adapter for Enzyme: same as evalPath(), but the result is
// written to `value` so it can be marked enzyme_dup (see d_evalPath and
// evalPathFwd).
void __evalPath(const Scene &scene, LightPath &path, Spectrum &value)
{
    value = evalPath(scene, path);
}

// Reverse-mode derivative of evalPath(): seeds the output adjoint with d_value
// (taken by value) and accumulates adjoints into d_scene and d_path. Because
// evalPath() only reads the cached per-vertex fields, this mainly fills the
// adjoints of those fields (value, nee_bsdf, bsdf_bsdf) in d_path.vertices,
// which d_evalVertex() consumes next.
// Compiled only when both ENZYME and PATH are defined; otherwise a no-op.
void d_evalPath(const Scene &scene, Scene &d_scene,
                const LightPath &path, LightPath &d_path,
                Spectrum d_value)
{
    // Primal output slot for the adapter; its value is not used here.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)__evalPath,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &path, &d_path,
                      enzyme_dup, &value, &d_value);
#endif
}

// Primal evaluation of a path's contribution. It refreshes vertex geometry
// (getPath), caches the per-vertex factors (evalVertex) and combines them
// (evalPath). `path` is modified in place. Returns zero if getPath() fails.
// NOTE(review): the reason for optnone is not visible here; presumably it
// relates to differentiating this function through __eval in baseline() —
// confirm before removing.
__attribute__((optnone)) Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    const bool geometryOk = getPath(scene, path);
    if (!geometryOk)
        return Spectrum::Zero();
    evalVertex(scene, path, sampler);
    return evalPath(scene, path);
}

// Out-parameter adapter for Enzyme: same as eval(), but the result is written
// to `ret` so it can be marked enzyme_dup (see baseline).
void __eval(const Scene &scene, LightPath &path, RndSampler *sampler, Spectrum &ret)
{
    ret = eval(scene, path, sampler);
}

// Staged reverse-mode differentiation of eval() ("algorithm 1").
// It first resizes and zeroes the path adjoints, then recomputes the primal
// (getPath, evalVertex, evalPath). It then back-propagates in reverse order:
//   d_evalPath   : output adjoint -> adjoints of cached per-vertex factors
//   d_evalVertex : per-vertex factor adjoints -> vertex geometry + d_scene
//   d_getPath    : vertex geometry adjoints -> d_scene
// Returns early (no gradient) if getPath() fails or the primal value is not
// finite. The stage order matters: each stage consumes adjoints the previous
// one produced.
void d_eval(const Scene &scene, Scene &d_scene,
            LightPathAD &pathAD,
            Spectrum d_value, RndSampler *sampler)
{
    auto &path = pathAD.val;
    auto &d_path = pathAD.der;
    d_path.resize(pathAD.val.vertices.size());
    d_path.setZero();
    if (!getPath(scene, path))
        return;
    evalVertex(scene, path, sampler);
    Spectrum value = evalPath(scene, path);
    if (!value.allFinite())
        return;
    d_evalPath(scene, d_scene, path, d_path, d_value); // algorithm 1
    d_evalVertex(scene, d_scene, path, d_path, sampler);
    d_getPath(scene, d_scene, path, d_path);
}

// Reference ("baseline") gradient: differentiates the whole eval() pipeline
// in a single Enzyme reverse-mode call, instead of the staged d_eval().
// Adjoints accumulate into d_scene and pathAD.der; sampler is constant.
// NOTE(review): unlike d_eval(), this does not resize or zero pathAD.der;
// presumably the caller must — confirm.
// Compiled only when both ENZYME and PATH are defined; otherwise a no-op.
void baseline(const Scene &scene, Scene &d_scene,
              LightPathAD &pathAD,
              Spectrum d_value, RndSampler *sampler)
{
    // Primal output slot for the adapter; its value is not used here.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff((void *)__eval,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &pathAD.val, &pathAD.der,
                      enzyme_const, sampler,
                      enzyme_dup, &value, &d_value);
#endif
}

// Out-parameter adapter for Enzyme: same as evalFwd(), but the result is
// written to `ret` so it can be marked enzyme_dup (used by the disabled
// single-shot branch of d_evalFwd). evalFwd() is defined further below;
// presumably it is declared in algorithm1.h.
void __evalFwd(const Scene &scene, LightPath &path, RndSampler *sampler, Spectrum &ret) {
    ret = evalFwd(scene, path, sampler);
}

// Forward-mode (tangent) evaluation of evalPath() via Enzyme.
// Uses the tangents already stored in d_scene and d_path (e.g. the cached
// per-vertex factor tangents produced by evalVertexFwd()).
// Returns {value, d_value}: the primal path contribution and its tangent.
// Without Enzyme support compiled in, both are returned as zero.
std::pair<Spectrum, Spectrum> evalPathFwd(const Scene &scene, Scene &d_scene,
            const LightPath &path, LightPath &d_path) {
    [[maybe_unused]] Spectrum value(0.);
    [[maybe_unused]] Spectrum d_value(0.);
    // The other Enzyme differentiations of __evalPath in this file
    // (d_evalPath) are gated on PATH. Gating only on VOLPATH made forward mode
    // silently return zero in PATH builds, so accept PATH as well. VOLPATH is
    // kept so builds that enabled this before still do.
#if defined(ENZYME) && (defined(PATH) || defined(VOLPATH))
    __enzyme_fwddiff((void *)__evalPath,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &path, &d_path,
                     enzyme_dup, &value, &d_value);
#endif
    return {value, d_value};
}

// Forward-mode derivative of getPoint() for one vertex: pushes the scene
// tangent d_scene through getPoint() into the vertex tangent d_v.
// Gated only on ENZYME, unlike the reverse-mode d_getPoint(), which also
// requires PATH.
// NOTE(review): `v` is const here but getPoint() writes to it; this relies on
// the referenced object not actually being const.
void getPointFwd(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v) {
    #if defined(ENZYME)
    __enzyme_fwddiff((void *)getPoint,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &v, &d_v);
    #endif
}

// Forward-mode counterpart of getPath(): for every vertex whose type has the
// EVSurface or EVVolume bit, pushes the scene tangent d_scene through
// getPoint() into the matching entry of d_path.vertices. Other vertices
// (e.g. the sensor vertex) are skipped.
void getPathFwd(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path) {
    const std::size_t count = path.vertices.size();
    for (std::size_t idx = 0; idx < count; ++idx) {
        const auto &vtx = path.vertices[idx];
        if (!(vtx.type & EVSurface || vtx.type & EVVolume))
            continue;
        getPointFwd(scene, d_scene, vtx, d_path.vertices[idx]);
    }
}

// Forward-mode derivative of evalVertex(): recomputes the cached per-vertex
// factors in `path` and their tangents in `d_path`, driven by the tangents in
// d_scene and d_path. evalVertex() returns nothing; its outputs are written
// into the vertices, so no output slot or output tangent is needed here.
// NOTE(review): `path` is const here but evalVertex() writes into it; this
// relies on the referenced path not actually being const (d_evalFwd passes
// the non-const pathAD.val).
// Compiles to a no-op unless ENZYME is defined.
void evalVertexFwd(const Scene &scene, Scene &d_scene,
                   const LightPath &path, LightPath &d_path,
                   RndSampler *sampler) {
    #ifdef ENZYME
    __enzyme_fwddiff((void *)evalVertex,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &path, &d_path,
                     enzyme_const, sampler);
    #endif
}

// Direct evaluation of a path's contribution. It computes the same quantity
// evalPath() assembles after evalVertex(), but recomputes every factor inline
// instead of reading cached vertex fields. It refreshes vertex geometry with
// getPath() first, and returns zero if that fails, if the path has fewer than
// two vertices, or if the first hit is invalid.
// Side effect: evalVertexContrib() overwrites curV.wi on the path vertices.
Spectrum evalFwd(const Scene &scene, LightPath &path, RndSampler *sampler) {
    if (!getPath(scene, path))
        return Spectrum::Zero();

    const auto &ids = path.vs;
    if (ids.size() <= 1 || path.vertices[ids[1]].type == EVInvalid)
        return Spectrum::Zero();

    const Intersection &camV = path.vertices[0];
    const Intersection &firstV = path.vertices[ids[1]];

    Spectrum radiance = Spectrum::Zero();
    Spectrum throughput = Spectrum::Ones();
    throughput *= evalFirstVertexContrib(scene, sampler, path.pixel_idx, camV, firstV);
    if (firstV.type == EVEmitter)
        radiance += throughput * evalLastVertexContrib(scene, sampler, camV, firstV);

    const std::size_t count = ids.size();
    for (std::size_t k = 0; k + 1 < count; ++k) {
        Intersection &preV = path.vertices[ids.at(k)];
        Intersection &curV = path.vertices[ids.at(k + 1)];

        // Emitter sampling first, then BSDF sampling; a negative id means
        // "no such sample".
        for (int lightId : {curV.l_nee_id, curV.l_bsdf_id}) {
            if (lightId < 0)
                continue;
            Intersection &lightV = path.vertices[lightId];
            const auto bsdfVal = evalVertexContrib(scene, sampler, preV, curV, lightV);
            const auto le = evalLastVertexContrib(scene, sampler, curV, lightV);
            radiance += throughput * bsdfVal * le;
        }

        // Extend the throughput to the next path vertex, if there is one.
        if (k + 2 < count) {
            Intersection &nextV = path.vertices[ids.at(k + 2)];
            throughput *= evalVertexContrib(scene, sampler, preV, curV, nextV);
        }
    }
    return radiance;
}

// Forward-mode derivative of the path contribution.
// Returns {value, d_value}: the primal contribution and its tangent along the
// direction stored in d_scene. Both are zero if getPath() fails or the primal
// value is not finite.
// The active (#if 1) branch is staged. It zeroes the path tangents, refreshes
// the primal geometry, pushes tangents through getPoint (getPathFwd), then
// through the cached vertex factors (evalVertexFwd), and finally through
// evalPath (evalPathFwd). The order matters: each stage consumes tangents the
// previous one wrote into d_path.
// The disabled #else branch differentiates evalFwd() in a single Enzyme call.
// Note it does not zero pathAD.der.
std::pair<Spectrum, Spectrum> d_evalFwd(const Scene &scene, Scene &d_scene, LightPathAD &pathAD, RndSampler *sampler) {
#if 1
    auto &path = pathAD.val;
    auto &d_path = pathAD.der;
    d_path.resize(pathAD.val.vertices.size());
    d_path.setZero();
    if (!getPath(scene, path))
        return {Spectrum(0.), Spectrum(0.)};
    getPathFwd(scene, d_scene, path, d_path);
    evalVertexFwd(scene, d_scene, path, d_path, sampler);
    auto [value, d_value] = evalPathFwd(scene, d_scene, path, d_path);
    if (!value.allFinite())
        return {Spectrum(0.), Spectrum(0.)};
    return {value, d_value};
#else
    [[maybe_unused]] Spectrum value(0.);
    [[maybe_unused]] Spectrum d_value(0.);
#if defined(ENZYME)
    __enzyme_fwddiff((void *)__evalFwd,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &pathAD.val, &pathAD.der,
                     enzyme_const, sampler,
                     enzyme_dup, &value, &d_value);
#endif
    return {value, d_value};
#endif
}

NAMESPACE_END(algorithm1)
