#include "algorithm1.h"
#include <render/scene.h>

NAMESPACE_BEGIN(algorithm1_pathwas)

bool getPoint(const Scene &scene, Intersection &v)
{
    if (v.type == EVSensor)
    {
        v.p = scene.camera.cpos;
        return true;
    }
    if (v.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[v.shape_id];
        const Vector3i &ind = shape->indices[v.triangle_id];
        const Vector &v0 = shape->getVertex(ind[0]),
                     &v1 = shape->getVertex(ind[1]),
                     &v2 = shape->getVertex(ind[2]);
        v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
              v.barycentric.x() * v1 +
              v.barycentric.y() * v2;
        Vector geo_n = shape->getFaceNormal(v.triangle_id);
        Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
        v.geoFrame = Frame(geo_n);
        v.shFrame = Frame(sh_n);
        v.J = shape->getArea(v.triangle_id);
        v.J /= detach(v.J);
        return true;
    }
    if (v.type == EVInvalid)
        return true;
    assert(false);
    return false;
}

// Runs getPoint on every vertex stored in the path, in order.
// Stops and returns false at the first vertex getPoint rejects; returns true
// when all vertices were processed.
bool getPath(const Scene &scene, LightPathWAS&path)
{
    for (auto &vertex : path.vertices)
    {
        if (!getPoint(scene, vertex))
            return false;
    }
    return true;
}

// Contribution of the sensor vertex: the camera filter of pixel `pixel_idx`
// evaluated towards `nextV`.
//   - nextV has the EVVolume bit: filter evaluated at the point only.
//   - nextV has the EVSurface bit: filter evaluated with the point and its
//     geometric normal (this result wins if both bits are set).
// Returns zero when neither bit is set. `sampler` is not used here.
Spectrum evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV)
{
    assert(curV.type & EVSensor);

    const auto px = pixel_idx[0];
    const auto py = pixel_idx[1];

    Spectrum contrib = Spectrum::Zero();
    if (nextV.type & EVVolume)
        contrib = scene.camera.evalFilter(px, py, nextV.p);
    if (nextV.type & EVSurface)
        contrib = scene.camera.evalFilter(px, py, nextV.p, nextV.geoFrame.n);
    return contrib;
}
// Out-parameter form of evalFirstVertexContrib: stores the result in `value`
// instead of returning it. This is the function pointer that
// d_evalFirstVertexContrib hands to __enzyme_autodiff.
void __evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalFirstVertexContrib(scene, sampler, pixel_idx, curV, nextV);
}

// Contribution of an interior surface vertex `curV` on the segment
// preV -> curV -> nextV:
//     bsdf(curV, wo) * geometric(curV, nextV) * curV.J / curV.pdf
// Side effect: `curV.wi` is overwritten with the local-frame direction from
// curV towards preV before the BSDF is evaluated.
// Both curV and nextV are asserted to be surface vertices; in builds without
// asserts, a non-surface vertex simply skips its factor. `sampler` is unused.
Spectrum evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV)
{
    // Incoming direction, expressed in curV's local frame.
    curV.wi = curV.toLocal((preV.p - curV.p).normalized());
    assert(curV.type != EVInvalid);

    Spectrum contrib = Spectrum::Ones();

    // BSDF term at the current vertex.
    assert(curV.type & EVSurface);
    if (curV.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[curV.shape_id];
        const BSDF *bsdf = scene.bsdf_list[shape->bsdf_id];
        const Vector to_next = (nextV.p - curV.p).normalized();
        contrib *= bsdf->eval(curV, curV.toLocal(to_next));
    }

    // Geometric term towards the next vertex.
    assert(nextV.type & EVSurface);
    if (nextV.type & EVSurface)
        contrib *= geometric(curV.p, nextV.p, nextV.geoFrame.n);

    // Area Jacobian over the sampling density of curV.
    contrib *= curV.J / curV.pdf;
    return contrib;
}
// Out-parameter form of evalVertexContrib: stores the result in `value`
// instead of returning it. This is the function pointer that
// d_evalVertexContrib hands to __enzyme_autodiff.
// Like evalVertexContrib, it writes `curV.wi` as a side effect.
void __evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV, Spectrum &value)
{
    value = evalVertexContrib(scene, sampler, preV, curV, nextV);
}

// Emission term at the light vertex `neeV` as seen from `curV`:
//     Le(n_nee, direction neeV -> curV) * neeV.J / neeV.pdf
// The emitter is looked up through the shape that neeV lies on.
Spectrum evalLastVertexContrib(const Scene &scene,
                               const Intersection &curV, const Intersection &neeV)
{
    const Shape *light_shape = scene.shape_list[neeV.shape_id];
    const Emitter *light = scene.emitter_list[light_shape->light_id];
    const Vector to_cur = (curV.p - neeV.p).normalized();
    return light->eval(neeV.geoFrame.n, to_cur) * neeV.J / neeV.pdf;
}
// Out-parameter form of evalLastVertexContrib: stores the result in `value`
// instead of returning it. This is the function pointer that
// d_evalLastVertexContrib hands to __enzyme_autodiff.
void __evalLastVertexContrib(const Scene &scene,
                             const Intersection &curV, const Intersection &neeV, Spectrum &value)
{
    value = evalLastVertexContrib(scene, curV, neeV);
}

// Differentiates __evalFirstVertexContrib through __enzyme_autodiff.
//   d_value            : shadow supplied for the output `value`. It is a
//                        by-value parameter, so nothing written into it is
//                        visible to the caller.
//   d_scene, d_curV,
//   d_nextV            : shadows paired with scene, curV and nextV.
//   sampler, pixel_idx : passed as enzyme_const (no shadow).
// The body is empty unless both ENZYME and PATH are defined.
// NOTE(review): presumably Enzyme accumulates the gradients into the d_*
// shadows rather than overwriting them — confirm against the Enzyme docs.
void d_evalFirstVertexContrib(const Scene &scene, Scene &d_scene,
                              RndSampler *sampler, const Array2i &pixel_idx,
                              const Intersection &curV, Intersection &d_curV,
                              const Intersection &nextV, Intersection &d_nextV,
                              Spectrum d_value)
{
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    // Argument order must mirror the parameter list of __evalFirstVertexContrib.
    __enzyme_autodiff((void *)__evalFirstVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, sampler,
                      enzyme_const, &pixel_idx,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &nextV, &d_nextV,
                      enzyme_dup, &value, &d_value);
#endif
}
// Differentiates __evalVertexContrib through __enzyme_autodiff.
//   d_value           : shadow supplied for the output `value` (by-value
//                       parameter; changes are not visible to the caller).
//   d_scene, d_preV,
//   d_curV, d_nextV   : shadows paired with scene, preV, curV and nextV.
//   sampler           : passed as enzyme_const (no shadow).
// The body is empty unless both ENZYME and PATH are defined.
// NOTE(review): `curV` is const here, but __evalVertexContrib takes it as a
// non-const reference and evalVertexContrib writes `curV.wi`; confirm that the
// objects passed in are never actually const.
void d_evalVertexContrib(const Scene &scene, Scene &d_scene,
                         RndSampler *sampler,
                         const Intersection &preV, Intersection &d_preV,
                         const Intersection &curV, Intersection &d_curV,
                         const Intersection &nextV, Intersection &d_nextV,
                         Spectrum d_value)
{
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    // Argument order must mirror the parameter list of __evalVertexContrib.
    __enzyme_autodiff((void *)__evalVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, sampler,
                      enzyme_dup, &preV, &d_preV,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &nextV, &d_nextV,
                      enzyme_dup, &value, &d_value);
#endif
}
// Differentiates __evalLastVertexContrib through __enzyme_autodiff.
//   d_value                  : shadow supplied for the output `value`
//                              (by-value parameter; changes are not visible to
//                              the caller).
//   d_scene, d_curV, d_neeV  : shadows paired with scene, curV and neeV.
// The body is empty unless both ENZYME and PATH are defined.
void d_evalLastVertexContrib(const Scene &scene, Scene &d_scene,
                             const Intersection &curV, Intersection &d_curV,
                             const Intersection &neeV, Intersection &d_neeV,
                             Spectrum d_value)
{
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    // Argument order must mirror the parameter list of __evalLastVertexContrib.
    __enzyme_autodiff((void *)__evalLastVertexContrib,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &curV, &d_curV,
                      enzyme_dup, &neeV, &d_neeV,
                      enzyme_dup, &value, &d_value);
#endif
}

void evalVertex(const Scene &scene, LightPathWAS &path, RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    const auto &vtxIds = path.vs;

    Intersection &camV = vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    Intersection &firstV = vertices[vtxIds[1]];
    camV.value = evalFirstVertexContrib(scene, sampler, pixel_idx, camV, firstV);
    if (firstV.type == EVEmitter)
    {
        camV.nee_bsdf = evalLastVertexContrib(scene, camV, firstV);
        
        // std::cout << "<-------------CamV.nee_bsdf------------>" << std::endl;
        // std::cout << camV.nee_bsdf << std::endl;
    }
    for (int i = 1; i < vtxIds.size(); i++)
    {
        Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &curV = vertices[vtxIds[i]];
        if (i < vtxIds.size() - 1)
        {
            Intersection &nextV = vertices[vtxIds[i + 1]];
            curV.value = evalVertexContrib(scene, sampler, preV, curV, nextV);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            Intersection &neeV = vertices[nee_id];
            Spectrum value = evalVertexContrib(scene, sampler, preV, curV, neeV);
            if (nee_id == curV.l_nee_id){
                curV.nee_bsdf = value;
                // std::cout << "<-------------V[" << i << "].nee_bsdf------------>" << std::endl;
                // std::cout << curV.nee_bsdf << std::endl;
            }
            if (nee_id == curV.l_bsdf_id)
                curV.bsdf_bsdf = value;

            // compute Le
            neeV.value = evalLastVertexContrib(scene, curV, neeV);
        }
    }
}

void d_evalVertex(const Scene &scene, Scene &d_scene,
                  const LightPathWAS &path, LightPathWAS &d_path,
                  RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    auto &d_vertices = d_path.vertices;
    const auto &vtxIds = path.vs;

    const Intersection &camV = vertices[0];
    Intersection &d_camV = d_vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    const Intersection &firstV = vertices[vtxIds[1]];
    Intersection &d_firstV = d_vertices[vtxIds[1]];
    d_evalFirstVertexContrib(
        scene, d_scene, sampler, pixel_idx,
        camV, d_camV, firstV, d_firstV,
        d_camV.value);

    if (firstV.type == EVEmitter)
    {
        d_evalLastVertexContrib(
            scene, d_scene,
            camV, d_camV, firstV, d_firstV,
            d_camV.nee_bsdf);
    }

    for (int i = 1; i < vtxIds.size(); i++)
    {
        const Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &d_preV = d_vertices[vtxIds[i - 1]];
        const Intersection &curV = vertices[vtxIds[i]];
        Intersection &d_curV = d_vertices[vtxIds[i]];
        if (i < vtxIds.size() - 1)
        {
            const Intersection &nextV = vertices[vtxIds[i + 1]];
            Intersection &d_nextV = d_vertices[vtxIds[i + 1]];
            d_evalVertexContrib(
                scene, d_scene, sampler,
                preV, d_preV, curV, d_curV, nextV, d_nextV,
                d_curV.value);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            const Intersection &neeV = vertices[nee_id];
            Intersection &d_neeV = d_vertices[nee_id];
            if (nee_id == curV.l_nee_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.nee_bsdf);
            if (nee_id == curV.l_bsdf_id)
                d_evalVertexContrib(
                    scene, d_scene, sampler,
                    preV, d_preV, curV, d_curV, neeV, d_neeV,
                    d_curV.bsdf_bsdf);
            // compute Le
            d_evalLastVertexContrib(
                scene, d_scene,
                curV, d_curV, neeV, d_neeV,
                d_neeV.value);
        }
    }
}

// Combines the per-vertex factors stored on the path into its total value.
//   - If the first hit has type EVEmitter: camV.value * camV.nee_bsdf.
//   - For every segment (preV -> curV) the running product of `value` factors
//     up to preV is multiplied with curV's light connections:
//         prefix * curV.nee_bsdf  * value of the emitter-sampled vertex
//         prefix * curV.bsdf_bsdf * value of the BSDF-sampled vertex
// Returns zero when the path has no valid first hit. `scene` is not read.
Spectrum evalPath(const Scene &scene, LightPathWAS&path)
{
    if (path.vs.size() <= 1 || path.vertices[path.vs[1]].type == EVInvalid)
        return Spectrum::Zero();

    Spectrum radiance = Spectrum::Zero();
    Spectrum prefix = Spectrum::Ones();

    // Emitter directly visible from the camera.
    const Intersection &camV = path.vertices[0];
    if (path.vertices[path.vs[1]].type == EVEmitter)
        radiance += camV.value * camV.nee_bsdf;

    const int num_segments = static_cast<int>(path.vs.size()) - 1;
    for (int s = 0; s < num_segments; ++s)
    {
        const Intersection &preV = path.vertices[path.vs.at(s)];
        const Intersection &curV = path.vertices[path.vs.at(s + 1)];
        prefix *= preV.value;

        // emitter sampling
        if (curV.l_nee_id >= 0)
            radiance += prefix * curV.nee_bsdf * path.vertices[curV.l_nee_id].value;

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
            radiance += prefix * curV.bsdf_bsdf * path.vertices[curV.l_bsdf_id].value;
    }
    return radiance;
}

// Out-parameter form of evalPath: stores the result in `value` instead of
// returning it. This is the function pointer that d_evalPath hands to
// __enzyme_autodiff.
void __evalPath(const Scene &scene, LightPathWAS &path, Spectrum &value)
{
    value = evalPath(scene, path);
}

// Differentiates __evalPath through __enzyme_autodiff.
//   d_value         : shadow supplied for the output `value` (by-value
//                     parameter; changes are not visible to the caller).
//   d_scene, d_path : shadows paired with scene and path.
// The body is empty unless both ENZYME and PATH are defined.
// NOTE(review): `path` is const here while __evalPath takes a non-const
// reference; evalPath only reads from it in the code visible in this file.
void d_evalPath(const Scene &scene, Scene &d_scene,
                const LightPathWAS &path, LightPathWAS &d_path,
                Spectrum d_value)
{
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Spectrum value;
#if defined(ENZYME) && defined(PATH)
    // Argument order must mirror the parameter list of __evalPath.
    __enzyme_autodiff((void *)__evalPath,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_dup, &path, &d_path,
                      enzyme_dup, &value, &d_value);
#endif
}

int sampleSurface(const Vector2 &rnd2, const Scene& scene, const Intersection &primal, Intersection& aux, Float sigma) {
    // Vector2 rnd2(_rnd2);
    Vector2 _rnd2(rnd2);
    // std::cout << "<------sample surface------>" << std::endl;
    Float rnd1 = _rnd2[0];
    _rnd2[0] = _rnd2[0] > 0.5 ? (_rnd2[0] - 0.5) * 2.0 : _rnd2[0] * 2.0;
    const Shape* shape = scene.getShape(primal.shape_id);
    const Vector3i &ind = shape->indices[primal.triangle_id];
    const Vector &v0 = shape->vertices[ind[0]], &v1 = shape->vertices[ind[1]], &v2 = shape->vertices[ind[2]];
    Vector2 barycentric = primal.barycentric;
    Vector x0 = (1.0f - primal.barycentric[0] - primal.barycentric[1])*v0 + primal.barycentric[0]*v1 + primal.barycentric[1]*v2;
    
    // std::cout << barycentric << std::endl;
    Vector coord_system[3];
    coord_system[0] = shape->faceNormals[primal.triangle_id];
    coordinateSystem(coord_system[0], coord_system[1], coord_system[2]);
    // Intersection its;
    Ray ray;
    Vector2 diskpos = squareToGaussianDisk(_rnd2, sigma);
    int axis = 0;
    
    Vector sphereCoord = Vector(diskpos(0), diskpos(1), 1e-2);
    Vector origin = x0 + sphereCoord[2] * coord_system[axis]// * direction
                        + sphereCoord[0] * coord_system[(axis + 1) % 3]
                        + sphereCoord[1] * coord_system[(axis + 2) % 3];
    
    Intersection its0, its1;
    //one ray down
    ray = Ray(origin, -coord_system[axis]);
    bool its0_exists = scene.rayIntersect(ray, its0, IntersectionMode::ESpatial);
    if (its0.t > 0.1) its0_exists = false;
    if (its0.shape_id != primal.shape_id) its0_exists = false;

    //one ray up
    ray = Ray(origin, coord_system[axis]);
    bool its1_exists = scene.rayIntersect(ray, its1, IntersectionMode::ESpatial);
    // std::cout << "origin: " << origin << std::endl;
    // std::cout << "ray_dir: " << coord_system[axis] << std::endl;
    // std::cout << "its.p: " << its0.p << ", " << its1.p << std::endl;
    if (its1.t > 0.1) its1_exists = false;
    if (its1.shape_id != primal.shape_id) its1_exists = false;
    // std::cout << "exists: " << its0_exists << ", " << its1_exists << std::endl;
    if (its0_exists && its1_exists) {
        if (rnd1 > 0.5) 
            aux = its0;
        else 
            aux = its1;
        return 2;
    } else if (its0_exists) {
        aux = its0;
        return 1;
    } else if (its1_exists) {
        aux = its1;
        return 1;
    }
    else {
        return 0;
    }
    // std::cout << "<----sample surface end---->" << std::endl;
}

void sampleSurfacePair(const Vector2 &rnd2, const Scene& scene, const Intersection &primal, Intersection& aux_0, Intersection& aux_1, int &found0, int &found1, Float sigma) {
    Vector2 _rnd2 = rnd2;
    found0 = sampleSurface(_rnd2, scene, primal, aux_0, sigma);
    _rnd2[1] = (_rnd2[1] + 0.5) > 1 ? _rnd2[1] - 0.5 : _rnd2[1] + 0.5;
    found1 = sampleSurface(_rnd2, scene, primal, aux_1, sigma);
}

// Density of the Gaussian-disk offset that leads from `primal` to `aux`:
// the displacement aux.p - primal.p is projected onto the two tangent
// directions of primal's triangle and passed to squareToGaussianDiskPdf.
Float sampleSurfacePdf(const Scene& scene, const Intersection &primal, const Intersection &aux, Float sigma) {
    // Orthonormal basis around the face normal of the primal triangle.
    Vector normal = scene.getShape(primal.shape_id)->faceNormals[primal.triangle_id];
    Vector tangent, bitangent;
    coordinateSystem(normal, tangent, bitangent);

    // Tangent-plane coordinates of the displacement.
    const Vector displacement = aux.p - primal.p;
    Vector2 disk_pos(tangent.dot(displacement), bitangent.dot(displacement));

    return squareToGaussianDiskPdf(disk_pos, sigma);
}

void harmonic_weight(const Vector &prim, const Vector &aux, const Float &rayLength, const Float &sigma, const Float &B, Float &w, Vector &dw) { // hand derived
    Float dist = (prim - aux).squaredNorm();
    Float D = pow(rayLength / sigma, power) * (1.0 - exp(-dist / sigma / rayLength));
    w = 1.0 / ((pow(D, 3) + B));
    dw = -3.0 * pow(D, 2) / pow((pow(D, 3) + B), 2) * pow(rayLength / sigma, power) * (exp(-dist / sigma / rayLength) / sigma / rayLength * 2.0 * (prim - aux));
}

// Direction-based harmonic weight (primal evaluation only).
// With the unit directions from `orig` towards `prim` and `aux`:
//     kappa = rayLength^4 / (2 * sigma^2)
//     g     = exp(kappa / 2 * (cos(angle between the directions) - 1))
//     w     = g^2 / (1 - g * B)^2
// This function is what d_harmonic_weight_dir differentiates, so it is kept
// free of anything but the arithmetic above.
void harmonic_weight_dir_fwd(const Vector &prim, const Vector &aux, const Vector orig, const Float &rayLength, const Float &sigma, const Float &B, Float &w) { // hand derived
    const Vector prim_dir = (prim - orig).normalized();
    const Vector aux_dir = (aux - orig).normalized();

    // rayLength enters squared twice: dist = rayLength^2, kappa uses dist^2.
    const Float dist = rayLength * rayLength;
    const Float kappa = (dist * dist) / (2 * sigma * sigma);

    // The previous unused intermediate `exp(kappa * (1 - cos)) - 1` was dead
    // code (one wasted exp per call) and has been removed.
    const Float gauss = exp(kappa / 2.0 * (prim_dir.dot(aux_dir) - 1));
    w = pow(gauss, 2) / pow(1.0 - gauss * B, 2);
}

// Differentiates harmonic_weight_dir_fwd with respect to `prim` through
// __enzyme_autodiff: `dw` is the shadow paired with `prim`, and the output `w`
// is seeded with an adjoint of 1. All other inputs are enzyme_const.
// NOTE(review): unlike the other d_* functions in this file, this call is not
// wrapped in `#if defined(ENZYME) && defined(PATH)` — confirm that is intended.
// NOTE(review): harmonic_weight_dir_fwd takes `orig` by value, yet `&orig` is
// passed here; this relies on how the compiler passes Vector by value — verify.
// NOTE(review): presumably Enzyme adds into `dw` rather than overwriting it, so
// callers should pass it zero-initialised — confirm against the Enzyme docs.
void d_harmonic_weight_dir(const Vector &prim, const Vector &aux, const Vector orig, const Float &rayLength, const Float &sigma, const Float &B, Float &w, Vector &dw) {
    // Adjoint seed for the scalar output w.
    Float _dw = 1.0;
    __enzyme_autodiff((void *)harmonic_weight_dir_fwd,
                      enzyme_dup, &prim, &dw,
                      enzyme_const, &aux,
                      enzyme_const, &orig,
                      enzyme_const, &rayLength,
                      enzyme_const, &sigma,
                      enzyme_const, &B,
                      enzyme_dup, &w, &_dw);
}

void harmonic_weight_dir(const Vector &prim, const Vector &aux, const Vector orig, const Float &rayLength, const Float &sigma, const Float &B, Float &w, Vector &dw) { // hand derived
    harmonic_weight_dir_fwd(prim, aux, orig, rayLength, sigma, B, w);
    Float _w = w;
    d_harmonic_weight_dir(prim, aux, orig, rayLength, sigma, B, _w, dw);
}

void sampleVertex(const Scene& scene, RndSampler *sampler, const Intersection &cur, const Intersection &prev, 
                std::vector<WASCache> &cache_list) {
    
    // std::cout << "<---------------sample Vertex-------------->" << std::endl;
    int valid_sample[64];
    Float rayLength = std::sqrt((prev.p - cur.p).norm());
    Float sigma = base_sigma * rayLength;
    bool ANTITHETIC = true;
    assert(!ANTITHETIC || (NUM_AUX_SAMPLES % 2 == 0));
    Intersection antithetic_buffer;
    
    for (int i = 0; i < NUM_AUX_SAMPLES; i++) {
        // std::cout << "<-------------iter " << i << " ------------>" << std::endl;
        Intersection x_aux;
        x_aux.shape_id = cur.shape_id;
        WASCache was_cache;
        // std::cout << "x bary" << x->its.barycentric << std::endl;
        if (!ANTITHETIC) {
            valid_sample[i] = sampleSurface(sampler->next2D(), scene, cur, x_aux, sigma);
        } else if (i % 2 == 0) {
            sampleSurfacePair(sampler->next2D(), scene, cur, 
                                x_aux, antithetic_buffer, valid_sample[i], valid_sample[i + 1], sigma);
        } else {
            x_aux = antithetic_buffer;
        }
        
        if (valid_sample[i] == 0) {
            was_cache.VALID = false;
            cache_list[i] = was_cache;
            continue;
        }
        
        Vector n_local = cur.geoFrame.toLocal(x_aux.geoFrame.n);
        Float area_J = n_local[2];
        Float pdf = sampleSurfacePdf(scene, cur, x_aux, sigma) * area_J / valid_sample[i];
        was_cache.aux = x_aux;

        Intersection isec_half;
        Vector dir = (x_aux.p - prev.p).normalized();
        Ray ray(prev.p, dir);
        bool half_found = scene.rayIntersect(ray, true, isec_half);
        Float B = 0.0;
        if (!half_found) {
            was_cache.VALID = false;
            cache_list[i] = was_cache;
            continue;
        }
        // if (dir.dot(prev.geoFrame.n) < 1e-2 && !x_aux.ptr_bsdf->isTwosided()) { // if aux_point is in the back hemisphere
        //     B = (scene.shape_list[isec_half.shape_id])->B(ray, isec_half);
        //     was_cache.force_zero = true;
        // }
        if (isec_half.t < 1e-5) {
            B = (scene.shape_list[isec_half.shape_id])->B(ray, isec_half);
            was_cache.force_zero = true;
        }
        if ((isec_half.p - x_aux.p).norm() < 1e-3) {
            B = (scene.shape_list[isec_half.shape_id])->B(ray, isec_half);
            was_cache.force_zero = true;
        }
        if (isec_half.shape_id < 0) {
            was_cache.VALID = false;
            cache_list[i] = was_cache;
            continue;
        }
        // harmonic_weight(cur.p, x_aux.p, rayLength, sigma, B, was_cache.w, was_cache.dw);
        harmonic_weight_dir(cur.p, x_aux.p, prev.p, rayLength, sigma, B, was_cache.w, was_cache.dw);
        was_cache.w /= pdf;
        was_cache.dw /= pdf;
        if ((isec_half.p - x_aux.p).norm() > 1e-3) {
            // std::cout << "<-------------sampleVertex " << i << "------------>" << std::endl;
            // std::cout << "force_zero: " << was_cache.force_zero << std::endl;
            // std::cout << "prim: " << cur.p << ", " << cur.shape_id << ", " << cur.triangle_id << std::endl;
            // std::cout << "aux: " << x_aux.p << ", " << x_aux.shape_id << ", " << x_aux.triangle_id << std::endl;
            // std::cout << "w, dw: " << was_cache.w << ", " << was_cache.dw << std::endl;
            // std::cout << "aux_half: " << isec_half.p << ", " << isec_half.shape_id << std::endl;
        }
        was_cache.aux_half = isec_half;
        cache_list[i] = was_cache;
    }
    assert(cache_list.size() == NUM_AUX_SAMPLES);
    // std::cout << cache_list.size() << "\n";
    // std::cout << "<-------------sample Vertex end------------>" << std::endl;
}

// Fills the auxiliary-sample caches of a path by calling sampleVertex for
//   - every path vertex after the first hit (seen from its predecessor), and
//   - the emitter-sampled and BSDF-sampled light vertices attached to each
//     path vertex (seen from that vertex).
// Segment 0 is camera -> first hit; the first hit itself gets no cache, only
// its light vertices do. Returns immediately when there is no valid first hit.
void sampleWarp(const Scene &scene, LightPathWAS &path, RndSampler *sampler){
    if (path.vs.size() <= 1 || path.vertices[path.vs[1]].type == EVInvalid)
        return;

    const int num_segments = static_cast<int>(path.vs.size()) - 1;
    for (int seg = 0; seg < num_segments; ++seg)
    {
        const auto cur_id = path.vs.at(seg + 1);
        Intersection &preV = path.vertices[path.vs.at(seg)];
        Intersection &curV = path.vertices[cur_id];

        // path vertex itself (skipped for the first hit)
        if (seg > 0)
            sampleVertex(scene, sampler, curV, preV, path.vertex_cache[cur_id]);

        // emitter sampling
        if (curV.l_nee_id >= 0)
            sampleVertex(scene, sampler, path.vertices[curV.l_nee_id], curV,
                         path.vertex_cache[curV.l_nee_id]);

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
            sampleVertex(scene, sampler, path.vertices[curV.l_bsdf_id], curV,
                         path.vertex_cache[curV.l_bsdf_id]);
    }
}

// Applies the INACTIVE_FN macro (defined outside this file) to sampleWarp.
// NOTE(review): presumably this marks sampleWarp as inactive for Enzyme, i.e.
// not to be differentiated — confirm against the macro definition.
INACTIVE_FN(sampleWarp, sampleWarp);

void velocity(const Vector &xS_0, const Vector &xS_1, const Vector &xS_2, const Float &uS, const Float &vS,
                const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float &uB, const Float &vB,
                const Vector &xD_0, const Vector &xD_1, const Vector &xD_2,
                Vector &x)
{
    const Vector &xB = (1.0 - uB - vB) * xB_0 +
               uB * xB_1 + vB * xB_2;
    const Vector &xS = (1.0 - uS - vS) * xS_0 +
               uS * xS_1 + vS * xS_2;
    Ray ray(xS, (xB - xS).normalized());
    
    Vector uvt = rayIntersectTriangle(xD_0, xD_1, xD_2, ray);
    Float u = uvt(0), v = uvt(1);
    x = (1.0 - u - v) * detach(xD_0) +
            u * detach(xD_1) +
            v * detach(xD_2);
}

// void d_velocity(const Vector &xS_0, Vector &d_xS_0, const Vector &xS_1, Vector &d_xS_1, const Vector &xS_2, Vector &d_xS_2, 
//                 const Float &uS, const Float &vS,
//                 const Vector &xB_0, Vector &d_xB_0, const Vector &xB_1, Vector &d_xB_1, const Vector &xB_2, Vector &d_xB_2, 
//                 const Float &uB, const Float &vB,
//                 const Vector &xD_0, Vector &d_xD_0, const Vector &xD_1, Vector &d_xD_1, const Vector &xD_2, Vector &d_xD_2, 
//                 Vector &d_x){
//     [[maybe_unused]] Vector x;
//     [[maybe_unused]] Float d_uS;
//     [[maybe_unused]] Float d_vS;
//     #if defined(ENZYME) && defined(PATH)
//         __enzyme_fwddiff((void *)velocity,
//                         enzyme_dup, &xS_0, &d_xS_0, 
//                         enzyme_dup, &xS_1, &d_xS_1, 
//                         enzyme_dup, &xS_2, &d_xS_2,
//                         enzyme_const, &uS,
//                         enzyme_const, &vS,
//                         enzyme_dup, &xB_0, &d_xB_0, 
//                         enzyme_dup, &xB_1, &d_xB_1, 
//                         enzyme_dup, &xB_2, &d_xB_2,
//                         enzyme_const, &uB, 
//                         enzyme_const, &vB,
//                         enzyme_dup, &xD_0, &d_xD_0, 
//                         enzyme_dup, &xD_1, &d_xD_1, 
//                         enzyme_dup, &xD_2, &d_xD_2, 
//                         enzyme_dup, &x, &d_x);
//     #endif
// }

void rayIntersectEdgeExt(const Scene& scene, const Intersection& origin, const Intersection &edge_ext, const Intersection &aux, Vector& x) {
    /* the desired dxdt will be stored in dxdt.der, since it has the correct size */
    // std::cout << "<-------------convUniDir------------>" << std::endl;
    const Shape *shapeB = scene.shape_list[edge_ext.shape_id];
    const Vector3i &indB = shapeB->getIndices(edge_ext.triangle_id);
    const Vector &xB_0 = shapeB->getVertex(indB[0]);
    const Vector &xB_1 = shapeB->getVertex(indB[1]);
    const Vector &xB_2 = shapeB->getVertex(indB[2]);
    const Float uB = edge_ext.barycentric[0],
                vB = edge_ext.barycentric[1];

    const Shape *shapeS = scene.shape_list[origin.shape_id];
    const auto &indS = shapeS->getIndices(origin.triangle_id);
    const Vector &xS_0 = shapeS->getVertex(indS[0]);
    const Vector &xS_1 = shapeS->getVertex(indS[1]);
    const Vector &xS_2 = shapeS->getVertex(indS[2]);
    const Float uS = origin.barycentric[0],
                vS = origin.barycentric[1];

    const Shape *shapeD = scene.shape_list[aux.shape_id];
    const auto &indD = shapeD->getIndices(aux.triangle_id);
    const Vector &xD_0 = shapeD->getVertex(indD[0]);
    const Vector &xD_1 = shapeD->getVertex(indD[1]);
    const Vector &xD_2 = shapeD->getVertex(indD[2]);

    velocity(xS_0, xS_1, xS_2, 
            uS, vS,
            xB_0, xB_1, xB_2,
            uB, vB,
            xD_0, xD_1, xD_2,
            x);
}

// Copies vertex `vertex_id` of shape `shape_id` into `vertex`.
// Out-parameter form so that d_getVertex can hand it to __enzyme_autodiff.
void getVertex(const Scene& scene,
               const int &shape_id, const int &vertex_id, Vector& vertex) {
    vertex = scene.shape_list[shape_id]->getVertex(vertex_id);
}

// Differentiates getVertex through __enzyme_autodiff: `d_vertex` is the shadow
// supplied for the output vertex and `d_scene` is the shadow paired with the
// scene. The two indices are passed as enzyme_const.
// The body is empty unless both ENZYME and PATH are defined.
void d_getVertex(const Scene& scene, Scene &d_scene,
                int shape_id, int vertex_id, Vector& d_vertex) {
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Vector vertex;
    #if defined(ENZYME) && defined(PATH)
        // Argument order must mirror the parameter list of getVertex.
        __enzyme_autodiff((void *)getVertex,
                        enzyme_dup, &scene, &d_scene,
                        enzyme_const, &shape_id,
                        enzyme_const, &vertex_id,
                        enzyme_dup, &vertex, &d_vertex);
    #endif
}

// Differentiates rayIntersectEdgeExt through __enzyme_autodiff: `d_x` is the
// shadow supplied for the output point `x` and `d_scene` is the shadow paired
// with the scene. The three intersections are passed as enzyme_const.
// The body is empty unless both ENZYME and PATH are defined.
void d_rayIntersectEdgeExt(const Scene& scene, Scene& d_scene, 
                            const Intersection& origin, const Intersection &edge_ext, const Intersection &aux, Vector& d_x) {
    // Storage for the primal output; only referenced inside the guarded call.
    [[maybe_unused]] Vector x;
    #if defined(ENZYME) && defined(PATH)
            // Argument order must mirror the parameter list of rayIntersectEdgeExt.
            __enzyme_autodiff((void *)rayIntersectEdgeExt,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &origin,
                            enzyme_const, &edge_ext,
                            enzyme_const, &aux,
                            enzyme_dup, &x, &d_x);
    #endif
}

void d_getWASPath(const Scene &scene, Scene &d_scene, LightPathWAS&path, LightPathWAS&d_path)
{
    if (path.vs.size() <= 1 || path.vertices[path.vs[1]].type == EVInvalid)
        return;
    for (int i = 0; i < path.vs.size() - 1; i++) // i==0: Camera, i==1: CameraVertex
    {
        Intersection &preV = path.vertices[path.vs.at(i)];
        Intersection &curV = path.vertices[path.vs.at(i + 1)];
        if (i > 0){
            
            for (int j = 0; j < path.vertex_cache[path.vs.at(i + 1)].size(); j++) { // for every aux_point
                WASCache& cache = path.vertex_cache[path.vs.at(i + 1)][j];
                WASCache& d_cache = d_path.vertex_cache[path.vs.at(i + 1)][j];
                if (!cache.VALID) continue;
                
                d_rayIntersectEdgeExt(scene, d_scene, preV, cache.aux_half, cache.aux, d_cache.aux.p);
            }
        }
        // emitter
        if (curV.l_nee_id >= 0)
        {
            for (int j = 0; j < path.vertex_cache[curV.l_nee_id].size(); j++) {
                WASCache& cache = path.vertex_cache[curV.l_nee_id][j];
                WASCache& d_cache = d_path.vertex_cache[curV.l_nee_id][j];
                if (!cache.VALID) continue;

                d_rayIntersectEdgeExt(scene, d_scene, curV, cache.aux_half, cache.aux, d_cache.aux.p);
            }
        }

        // bsdf
        if (curV.l_bsdf_id >= 0)
        {
            for (int j = 0; j < path.vertex_cache[curV.l_bsdf_id].size(); j++) {
                WASCache& cache = path.vertex_cache[curV.l_bsdf_id][j];
                WASCache& d_cache = d_path.vertex_cache[curV.l_bsdf_id][j];
                if (!cache.VALID) continue;
                d_rayIntersectEdgeExt(scene, d_scene, curV, cache.aux_half, cache.aux, d_cache.aux.p);
            }
        }
    }
}

void warpSurface(const std::vector<WASCache> &cache_list, 
                Vector &Warp, Float &divWarp){
    double Z = 0.0;
    Vector dZ;
    dZ.setZero();

    for (int i = 0; i < NUM_AUX_SAMPLES; i++) {
        if (!cache_list[i].VALID) continue;
        Z += cache_list[i].w;
        dZ += cache_list[i].dw;
    }

    Vector X_holder;
    X_holder.setZero();
    Float dwV = 0.0, wVd = 0.0;
    for (int i = 0; i < NUM_AUX_SAMPLES; i++) {
        const WASCache& cache = cache_list[i];
        if (!cache.VALID) continue;
        if (cache.force_zero) continue;
        X_holder += cache.w * cache.aux.p;
        dwV += cache.dw.dot(cache.aux.p);
        wVd += cache.w * dZ.dot(cache.aux.p);
    }

    Warp = X_holder / Z;
    divWarp = (dwV / Z - wVd / (Z * Z));
}

// Differentiates warpSurface through __enzyme_autodiff: `d_warp` and
// `d_div_warp` are the shadows supplied for the two outputs, and
// `d_cache_list` is the shadow paired with `cache_list`.
// The body is empty unless both ENZYME and PATH are defined.
void d_warpSurface(const std::vector<WASCache> &cache_list, std::vector<WASCache> &d_cache_list,
                Vector &d_warp, Float &d_div_warp)
{
    // Storage for the primal outputs; only referenced inside the guarded call.
    [[maybe_unused]] Vector warp;
    [[maybe_unused]] Float div_warp;
#if defined(ENZYME) && defined(PATH)
    // Argument order must mirror the parameter list of warpSurface.
    __enzyme_autodiff((void *)warpSurface,
                      enzyme_dup, &cache_list,  &d_cache_list,
                      enzyme_dup, &warp, &d_warp,
                      enzyme_dup, &div_warp, &d_div_warp);
#endif
}


void evalWarp(const Scene &scene, LightPathWAS &path)
{
    if (path.vs.size() <= 1 || path.vertices[path.vs[1]].type == EVInvalid)
        return;

    for (int i = 0; i < path.vs.size() - 1; i++) // i==0: Camera, i==1: CameraVertex
    {
        Intersection &preV = path.vertices[path.vs.at(i)];
        Intersection &curV = path.vertices[path.vs.at(i + 1)];
        bool valid;
        if (i > 0) {
            warpSurface(path.vertex_cache[path.vs.at(i + 1)], 
                        path.warped_X[path.vs.at(i + 1)], path.div_warped_X[path.vs.at(i + 1)]);
        }

        // emitter sampling
        if (curV.l_nee_id >= 0)
        {
            Intersection &neeV = path.vertices[curV.l_nee_id];
            warpSurface(path.vertex_cache[curV.l_nee_id], 
                        path.warped_X[curV.l_nee_id], path.div_warped_X[curV.l_nee_id]);
        }

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
        {
            Intersection &bsdfV = path.vertices[curV.l_bsdf_id];
            warpSurface(path.vertex_cache[curV.l_bsdf_id], 
                        path.warped_X[curV.l_bsdf_id], path.div_warped_X[curV.l_bsdf_id]);
        }
    }
}

void d_evalWarp(const Scene &scene,
                const LightPathWAS &path, LightPathWAS &d_path)
{
    if (path.vs.size() <= 1 || path.vertices[path.vs[1]].type == EVInvalid)
        return;
    for (int i = 0; i < path.vs.size() - 1; i++) // i==0: Camera, i==1: CameraVertex
    {
        const Intersection &preV = path.vertices[path.vs.at(i)];
        const Intersection &curV = path.vertices[path.vs.at(i + 1)];
        
        if (i > 0) {
            // std::cout << "<-------------d_evalwarp " << i << "------------>" << std::endl;
            // std::cout << "d_warp: " << d_path.warped_X[curV.l_nee_id] << "\n";
            // std::cout << "d_div_warp: " << d_path.div_warped_X[curV.l_nee_id] << "\n";
            d_warpSurface(path.vertex_cache[path.vs.at(i + 1)], d_path.vertex_cache[path.vs.at(i + 1)], 
                        d_path.warped_X[path.vs.at(i + 1)], d_path.div_warped_X[path.vs.at(i + 1)]);
        }

        // emitter sampling
        if (curV.l_nee_id >= 0)
        {
            const Intersection &neeV = path.vertices[curV.l_nee_id];
            // std::cout << "<-------------d_evalwarp " << i << "------------>" << std::endl;
            // std::cout << "d_warp: " << d_path.warped_X[curV.l_nee_id] << "\n";
            // std::cout << "d_div_warp: " << d_path.div_warped_X[curV.l_nee_id] << "\n";
            d_warpSurface(path.vertex_cache[curV.l_nee_id], d_path.vertex_cache[curV.l_nee_id], 
                        d_path.warped_X[curV.l_nee_id], d_path.div_warped_X[curV.l_nee_id]);
        }

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
        {
            const Intersection &bsdfV = path.vertices[curV.l_bsdf_id];
            d_warpSurface(path.vertex_cache[curV.l_bsdf_id], d_path.vertex_cache[curV.l_bsdf_id], 
                        d_path.warped_X[curV.l_bsdf_id], d_path.div_warped_X[curV.l_bsdf_id]);
        }
    }
    // std::cout << "<-------------d_warp end------------>" << std::endl;
}

// Accumulates the warp-divergence weighted contribution of a path.
//
// Walking consecutive pairs of path.vs, a running throughput is multiplied by
// the `value` of the earlier vertex of each pair. For every emitter-sampling
// (l_nee_id >= 0) or BSDF-sampling (l_bsdf_id >= 0) connection of the later
// vertex, the connection's throughput is multiplied by
//   - div_warped_X of every chain vertex path.vs[2 .. size-1], and
//   - div_warped_X of the sampled vertex itself,
// and each product is added to the result.
//
// Returns zero when path.vs has fewer than two entries or the vertex
// referenced by path.vs[1] is EVInvalid. `scene` and `d_vertex` are not read.
//
// NOTE(review): the sum over chain vertices always spans the whole path and
// does not depend on which vertex the connection starts from -- confirm this
// is intended rather than stopping at the current vertex.
Spectrum evalBoundary(const Scene &scene, const LightPathWAS &path, const LightPathWAS &d_vertex)
{
    if (path.vs.size() <= 1)
        return Spectrum::Zero();
    if (path.vertices[path.vs[1]].type == EVInvalid)
        return Spectrum::Zero();

    const int last = static_cast<int>(path.vs.size()) - 1;
    Spectrum boundary = Spectrum::Zero();
    Spectrum prefix = Spectrum::Ones();

    for (int seg = 0; seg < last; ++seg)
    {
        const Intersection &head = path.vertices[path.vs.at(seg)];
        const Intersection &tail = path.vertices[path.vs.at(seg + 1)];
        prefix *= (head.value);

        // emitter sampling
        if (tail.l_nee_id >= 0)
        {
            const Intersection &light = path.vertices[tail.l_nee_id];
            Spectrum weight = (prefix * tail.nee_bsdf * light.value);
            // The divergence is stored on the later vertex of each segment;
            // vs[1] belongs to the camera segment (pixel discontinuity) and is skipped.
            for (int k = 2; k <= last; ++k)
                boundary += weight * path.div_warped_X[path.vs.at(k)];
            boundary += weight * (path.div_warped_X[tail.l_nee_id]); // the NEE vertex itself
        }

        // bsdf sampling
        if (tail.l_bsdf_id >= 0)
        {
            const Intersection &hit = path.vertices[tail.l_bsdf_id];
            Spectrum weight = (prefix * tail.bsdf_bsdf * hit.value);
            for (int k = 2; k <= last; ++k)
                boundary += weight * path.div_warped_X[path.vs.at(k)];
            boundary += weight * (path.div_warped_X[tail.l_bsdf_id]); // the BSDF vertex itself
        }
    }
    return boundary;
}

// Out-parameter form of evalBoundary(): the result is stored in `value`
// instead of being returned. This is the function d_evalBoundary() hands to
// __enzyme_autodiff.
void __evalBoundary(const Scene &scene, const LightPathWAS &path, const LightPathWAS &d_vertex, Spectrum &value)
{
    value = evalBoundary(scene, path, d_vertex);
}

// Derivative of evalBoundary(), obtained by applying Enzyme to __evalBoundary.
//
// scene    : marked enzyme_const; no derivative is tracked for it.
// path     : primal path.
// d_path   : shadow of `path`; receives the derivatives (original note: only
//            the Warp and divWarp entries are used afterwards).
// d_vertex : dL/dx (or dL/dp); marked enzyme_const, i.e. used as a constant.
// d_value  : seed for the boundary value, passed as the shadow of the result.
//
// For reference, the hand-written adjoint this call replaces added
// dot(d_value, throughput) to d_path.div_warped_X for every chain vertex
// path.vs[2 .. size-1] and for the sampled NEE / BSDF vertex, once per
// emitter-sampling and BSDF-sampling connection.
void d_evalBoundary(const Scene &scene,
                    const LightPathWAS &path, LightPathWAS &d_path, // only use Warp and divWarp
                    const LightPathWAS &d_vertex, // here d_vertex is dL/dx (or dL/dp), and is used as a constant
                    Spectrum d_value)
{
    // Scratch storage for the primal result; only its shadow (d_value) matters.
    [[maybe_unused]] Spectrum value;
    // Every by-reference parameter of __evalBoundary is handed to Enzyme by
    // address -- d_vertex included, so it is not passed as a whole object.
    __enzyme_autodiff((void *)__evalBoundary,
                      enzyme_const, &scene,
                      enzyme_dup, &path, &d_path,
                      enzyme_const, &d_vertex,
                      enzyme_dup, &value, &d_value);
}

// Evaluates the contribution of `path`.
//
// getPath() first resolves every vertex of the path against the scene; if it
// reports failure the result is zero. Otherwise evalVertex() is run on the
// path and the value computed by evalPath() is returned. `path` is modified
// in place by these steps.
__attribute__((optnone)) Spectrum eval(const Scene &scene, LightPathWAS&path, RndSampler *sampler)
{
    if (getPath(scene, path))
    {
        evalVertex(scene, path, sampler);
        return evalPath(scene, path);
    }
    return Spectrum::Zero();
}

// Out-parameter form of eval(): the result is stored in `ret` instead of
// being returned.
void __eval(const Scene &scene, LightPathWAS&path, RndSampler *sampler, Spectrum &ret)
{
    ret = eval(scene, path, sampler);
}

void d_eval(const Scene &scene, Scene &d_scene,
            LightPathWASAD &pathAD,
            Spectrum d_value, RndSampler *sampler)
{
    auto &path = pathAD.val;
    auto &d_path = pathAD.der;
    
    LightPathWAS d_vertex = path;
    d_vertex.setZero();
    if (!getPath(scene, path))
        return;
    evalVertex(scene, path, sampler);
    Spectrum L = evalPath(scene, path);
    Spectrum dL = Spectrum(1.0);
    if (!L.allFinite())
        return;
    
    d_evalPath(scene, d_scene, path, d_vertex, dL); // gets dL/dF(x)
    d_evalVertex(scene, d_scene, path, d_vertex, sampler); // dL/dF(x) * dF(x)/dx

    sampleWarp(scene, path, sampler); // sample warp field

    evalWarp(scene, path);
    L = evalBoundary(scene, path, d_vertex);
    if (!L.allFinite())
        return;

    d_path.resize(path);
    d_path.setZero();
    Spectrum _dv = d_value;
    d_evalBoundary(scene, path, d_path, d_vertex, d_value);
    d_evalWarp(scene, path, d_path);
    d_getWASPath(scene, d_scene, path, d_path);
    d_path.setZero();
    path.setZero();
}

NAMESPACE_END(algorithm1_was)
