#include <render/scene.h>
#include "algorithm1.h"
#include "algorithm1_vol.h"

NAMESPACE_BEGIN(algorithm1_vol)

void __eval(const Scene &scene, LightPath &path, RndSampler *sampler, Spectrum &ret)
{
    ret = eval(scene, path, sampler);
}

void d_eval(const Scene &scene, Scene &d_scene,
            LightPathAD &pathAD,
            Spectrum d_value, RndSampler *sampler)
{
    auto &path = pathAD.val;
    auto &d_path = pathAD.der;
    d_path.resize(pathAD.val.vertices.size());
    d_path.setZero();
    if (!getPath(scene, path))
        return;
    evalVertex(scene, path, sampler);
    Spectrum value = evalPath(scene, path);
    if (!value.allFinite())
        return;
    d_evalPath(scene, d_scene, path, d_path, d_value); // algorithm 1
    d_evalVertex(scene, d_scene, path, d_path, sampler);
    d_getPath(scene, d_scene, path, d_path);
}

/// Reference reverse-mode derivative of the path contribution, obtained by
/// letting Enzyme differentiate __eval.
///
/// The Enzyme call is compiled only when both ENZYME and VOLPATH are defined.
/// Otherwise this function only prints and leaves `d_scene` and `pathAD.der`
/// untouched.
///
/// @param scene    Primal scene (active argument; its shadow is `d_scene`).
/// @param d_scene  Shadow of `scene` that receives Enzyme's derivatives.
/// @param pathAD   `val` is the primal path, `der` its shadow.
/// @param d_value  Adjoint seed for the evaluated Spectrum, used as the shadow
///                 of __eval's `ret`. It is taken by value, so the caller's
///                 copy is never touched by Enzyme.
/// @param sampler  Random sampler, marked constant (not differentiated).
void baseline(const Scene &scene, Scene &d_scene,
              LightPathAD &pathAD,
              Spectrum d_value, RndSampler *sampler)
{
    // NOTE(review): prints on every call; may be noisy if invoked per sample.
    fmt::print("baseline\n");
    // Primal output slot for __eval. It is zero-initialised to match
    // baselineFwd instead of being left indeterminate; __eval overwrites it.
    [[maybe_unused]] Spectrum value(0.);
#if defined(ENZYME) && defined(VOLPATH)
    __enzyme_autodiff((void *)__eval,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &pathAD.val, &pathAD.der,
                      enzyme_const, sampler,
                      enzyme_dup, &value, &d_value);
#endif
}

/// Forward-mode derivative of one light path's contribution, computed with the
/// hand-written tangent routines.
///
/// After getPath succeeds, it runs getPathFwd -> evalVertexFwd -> evalPathFwd.
/// `d_scene` is passed to each of these (presumably as the input tangent
/// direction; confirm in the *Fwd helpers).
///
/// @param scene    Primal scene.
/// @param d_scene  Scene tangent passed to the *Fwd routines.
/// @param pathAD   `val` is the primal path; `der` is its tangent, which is
///                 resized and zeroed here on every call.
/// @param sampler  Random sampler forwarded to evalVertexFwd.
/// @return (value, d_value) from evalPathFwd. Both are zero if getPath fails
///         or the value is not finite.
std::pair<Spectrum, Spectrum> evalFwd(const Scene &scene, Scene &d_scene,
                                      LightPathAD &pathAD, RndSampler *sampler)
{
    const auto zeroResult = [] {
        return std::pair<Spectrum, Spectrum>{Spectrum(0.), Spectrum(0.)};
    };

    auto &primal = pathAD.val;
    auto &tangent = pathAD.der;
    tangent.resize(primal.vertices.size());
    tangent.setZero();

    if (!getPath(scene, primal))
        return zeroResult();

    getPathFwd(scene, d_scene, primal, tangent);
    evalVertexFwd(scene, d_scene, primal, tangent, sampler);
    auto [radiance, d_radiance] = evalPathFwd(scene, d_scene, primal, tangent);

    if (radiance.allFinite())
        return {radiance, d_radiance};
    // Non-finite primal value: discard both value and tangent.
    return zeroResult();
}

std::pair<Spectrum, Spectrum> baselineFwd(const Scene &scene, Scene &d_scene,
                                          LightPathAD &pathAD, RndSampler *sampler)
{
    [[maybe_unused]] Spectrum value(0.);
    [[maybe_unused]] Spectrum d_value(0.);
#if defined(ENZYME) && defined(VOLPATH)
    __enzyme_fwddiff((void *)__eval,
                     enzyme_dup, &scene, &d_scene,
                     enzyme_dup, &pathAD.val, &pathAD.der,
                     enzyme_const, sampler,
                     enzyme_dup, &value, &d_value);
#endif
    return {value, d_value};
}

NAMESPACE_END(algorithm1_vol)