#include "algorithm1.h"
#include <render/scene.h>

NAMESPACE_BEGIN(algorithm1_bdptwas)

// Computes the accumulated throughput `path[index].value` of one path vertex.
//
// This is the primal that d_evalVertex hands to Enzyme, so the code is kept
// exactly as written.
//
//   scene  scene providing the camera, emitter list and BSDF list
//   path   sub-path being evaluated; only path[index].value is written, and
//          path[index-1].value must already hold the previous result
//   index  vertex to evaluate (0 is the first vertex of the sub-path)
void __evalVertex(const Scene& scene, LightPath& path, int index)
{
    const Camera& camera = scene.camera;
    Float G = 1.0f;
    Vector dir;
    if (index == 0) {
        // First vertex: no incoming segment, so no geometry term is applied.
        if (!path.isCameraPath) {
            // Light sub-path: emitter intensity divided by the sampling pdf, scaled by J.
            int emitter_id = path[0].ptr_shape->light_id;
            const Emitter* ptr_emitter = scene.emitter_list[emitter_id];
            path[0].value = ptr_emitter->getIntensity() / path[0].pdf * path[0].J;  // may need to change this if not using an area light
        } else {
            // Camera sub-path: camera response for this pixel at the first hit point.
            path[0].value = camera.eval(path.pixelIdx[0], path.pixelIdx[1], path[0].p, path[0].geoFrame.n) / path[0].pdf * path[0].J;
        }
    } else {
        // Unit direction from the previous vertex to this one, and the
        // |cos| / dist^2 factor measured at the current vertex.
        dir = path[index].p - path[index-1].p;
        Float dist = dir.norm();
        dir /= dist;
        G = std::abs(path[index].geoFrame.n.dot(dir)) / (dist * dist);
        if ( index == 1 && !path.isCameraPath) {
            // The previous vertex lies on the emitter: use its directional emission.
            int emitter_id = path[index-1].ptr_shape->light_id;
            const Emitter* ptr_emitter = scene.emitter_list[emitter_id];
            path[index].value = ptr_emitter->evalDirection(path[index-1].geoFrame.n, dir);
        } else {
            // The previous vertex is a scattering event: evaluate its BSDF toward
            // this vertex, with the transport mode chosen by the sub-path type.
            int bsdf_id = path[index-1].ptr_shape->bsdf_id;
            const BSDF* ptr_bsdf = scene.bsdf_list[bsdf_id];
            path[index].value = ptr_bsdf->eval(path[index-1], path[index-1].toLocal(dir),
                                               path.isCameraPath ? ERadiance : EImportanceWithCorrection);
        }
        // Apply J * G / pdf of this vertex, then chain with the previous throughput.
        path[index].value *= path[index].J * G / path[index].pdf;
        path[index].value *= path[index-1].value;
    }
}

// Differentiates __evalVertex with Enzyme's __enzyme_autodiff.
//
// The scene and the path are passed as duplicated (primal + shadow) arguments,
// the vertex index as a constant. Scene derivatives go into the shadow scene
// selected by the calling OpenMP thread's id; path derivatives are exchanged
// through pathAD.der.
void d_evalVertex(SceneAD &sceneAD, LightPathAD& pathAD, int index)
{
    auto &sceneShadow = sceneAD.gm.get(omp_get_thread_num());

    __enzyme_autodiff(
        (void *)__evalVertex,
        enzyme_dup,   &sceneAD.val, &sceneShadow,
        enzyme_dup,   &pathAD.val,  &pathAD.der,
        enzyme_const, index);
}

// Evaluates the contribution of the full path obtained by joining camera
// sub-path vertex `s` with light sub-path vertex `t`, scaled by the weight `w`.
//
// This is the primal that d_evalPath hands to Enzyme, so the code is kept
// exactly as written.
//
//   s == -1  the light vertex `t` is connected directly to the camera position
//   t == -1  the camera vertex `s` is itself treated as lying on an emitter
//   else     the two sub-path end points are connected to each other
//
//   value    output; overwritten with the weighted contribution
//   hack     never read here; d_evalPath passes it as a trailing constant to
//            work around a suspected Enzyme issue
void __evalPath(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                Float w, Spectrum& value, Float hack)
{
    if ( s == -1 ) {
        // Light sub-path end point seen directly by the camera.
        const Camera& camera = scene.camera;
        const Intersection& v_lgt = light_path[t];
        Vector dir = (camera.cpos - v_lgt.p).normalized();
        int bsdf_id = v_lgt.ptr_shape->bsdf_id;
        Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(v_lgt, v_lgt.toLocal(dir), EBSDFMode::EImportanceWithCorrection);
        Float camera_val = camera.eval(light_path.pixelIdx[0], light_path.pixelIdx[1], v_lgt.p);
        value = v_lgt.value * camera_val * bsdf_val;
    } else if ( t == -1 ) {
        // Camera sub-path end point used as an emitter sample; the outgoing
        // direction points back to the previous vertex (the camera when s == 0).
        const Camera& camera = scene.camera;
        const Intersection& v_cam = camera_path[s];
        const Vector& pre_p = (s == 0) ? camera.cpos : camera_path[s-1].p;
        Vector dir = (pre_p - v_cam.p).normalized();
        int emitter_id = v_cam.ptr_shape->light_id;
        value = scene.emitter_list[emitter_id]->eval(v_cam.geoFrame.n, dir) * v_cam.value;
    } else {
        // General connection; `dir` is the unit vector from the camera vertex to the light vertex.
        const Intersection& v_lgt = light_path[t];
        const Intersection& v_cam = camera_path[s];
        Vector dir = v_lgt.p - v_cam.p;
        Float dist2 = dir.squaredNorm();
        dir /= std::sqrt(dist2);
        Spectrum seg_lgt(0.f);
        if ( t == 0 ) {
            // The light vertex is on the emitter itself: directional emission toward the camera vertex.
            int emitter_id = v_lgt.ptr_shape->light_id;
            seg_lgt = scene.emitter_list[emitter_id]->evalDirection(v_lgt.geoFrame.n, -dir);
        } else {
            int bsdf_id = v_lgt.ptr_shape->bsdf_id;
            seg_lgt = scene.bsdf_list[bsdf_id]->eval(v_lgt, v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
        }
        // Inverse squared distance of the connecting segment.
        seg_lgt /= dist2;
        int bsdf_id = v_cam.ptr_shape->bsdf_id;
        value = seg_lgt * scene.bsdf_list[bsdf_id]->eval(v_cam, v_cam.toLocal(dir)) * v_lgt.value * v_cam.value;

    }
    value *= w;
}

// Differentiates __evalPath with Enzyme's __enzyme_autodiff and returns the
// primal path contribution.
//
// d_value is the shadow paired with the output spectrum. Scene derivatives go
// into the shadow scene selected by the calling OpenMP thread's id; sub-path
// derivatives are exchanged through the `.der` members of the two paths.
Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD& cameraPathAD, int s, LightPathAD& lightPathAD, int t,
                    Float w, Spectrum d_value)
{
    auto &sceneShadow = sceneAD.gm.get(omp_get_thread_num());

    Spectrum primal = Spectrum::Ones();
    // Trailing constant argument kept as a hack because of an Enzyme detach bug (?).
    Float trailingConst = 0.0f;

    __enzyme_autodiff(
        (void *)__evalPath,
        enzyme_dup,   &sceneAD.val,      &sceneShadow,
        enzyme_dup,   &cameraPathAD.val, &cameraPathAD.der,
        enzyme_const, s,
        enzyme_dup,   &lightPathAD.val,  &lightPathAD.der,
        enzyme_const, t,
        enzyme_const, w,
        enzyme_dup,   &primal,           &d_value,
        enzyme_const, trailingConst);

    return primal;
}

// Rebuilds the geometric quantities of path vertex `index` from the scene
// geometry: position, geometric/shading frames, the factor J, and (where a
// previous point exists) the local incident direction wi.
//
// This is the primal that d_getPoint hands to Enzyme, so the code is kept
// exactly as written.
void __getPoint(const Scene& scene, LightPath& path, int index) {
    Intersection &v = path[index];
    const Shape *shape = scene.shape_list[v.shape_id];
    const Vector3i &ind = shape->indices[v.triangle_id];
    const Vector &v0 = shape->getVertex(ind[0]),
                &v1 = shape->getVertex(ind[1]),
                &v2 = shape->getVertex(ind[2]);
    // Position from the stored barycentric coordinates and the triangle's vertices.
    v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
        v.barycentric.x() * v1 +
        v.barycentric.y() * v2;
    Vector geo_n = shape->getFaceNormal(v.triangle_id);
    Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
    v.geoFrame = Frame(geo_n);
    v.shFrame = Frame(sh_n);
    // J = area / detach(area).
    // NOTE(review): presumably detach() returns the same value without derivative
    // tracking, making J numerically 1 while still carrying the area's derivative — confirm.
    v.J = shape->getArea(v.triangle_id);
    v.J /= detach(v.J);

    if (path.isCameraPath) {
        // Camera sub-path: wi points to the previous vertex, or to the camera for the first vertex.
        const Vector& pre_p = (index == 0) ? scene.camera.cpos 
                                           : (path[index-1].p);
        Vector dir = (pre_p - v.p).normalized();
        v.wi = v.toLocal(dir);
    } else {
        // Light sub-path: the first vertex has no predecessor, so wi is left untouched for index 0.
        if (index > 0) {
            Vector dir = (path[index-1].p - v.p).normalized();
            v.wi = v.toLocal(dir);
        }
    }
}

// Differentiates __getPoint with Enzyme's __enzyme_autodiff.
//
// The scene and the path are passed as duplicated (primal + shadow) arguments,
// the vertex index as a constant. Scene derivatives go into the shadow scene
// selected by the calling OpenMP thread's id.
void d_getPoint(SceneAD &sceneAD, LightPathAD& pathAD, int index) {
    auto &sceneShadow = sceneAD.gm.get(omp_get_thread_num());

    __enzyme_autodiff(
        (void *)__getPoint,
        enzyme_dup,   &sceneAD.val, &sceneShadow,
        enzyme_dup,   &pathAD.val,  &pathAD.der,
        enzyme_const, index);
}


// Shoots a ray from a point on triangle S through a point on triangle B,
// intersects it with triangle D, and writes the hit point to `x`.
//
//   xS_0..xS_2, uS, vS  vertices and barycentric coordinates of the ray origin
//   xB_0..xB_2, uB, vB  vertices and barycentric coordinates of the point the ray passes through
//   xD_0..xD_2          vertices of the triangle that is intersected
//   x                   output point, rebuilt from the hit's (u, v) and the detached D vertices
//
// This function is differentiated through rayIntersectEdgeExt, so the code is
// kept exactly as written.
void velocity(const Vector &xS_0, const Vector &xS_1, const Vector &xS_2, const Float &uS, const Float &vS,
                const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float &uB, const Float &vB,
                const Vector &xD_0, const Vector &xD_1, const Vector &xD_2,
                Vector &x)
{
    // Barycentric interpolation of the two points that define the ray.
    const Vector &xB = (1.0 - uB - vB) * xB_0 +
               uB * xB_1 + vB * xB_2;
    const Vector &xS = (1.0 - uS - vS) * xS_0 +
               uS * xS_1 + vS * xS_2;
    Ray ray(xS, (xB - xS).normalized());
    
    // The first two components of the result are used as barycentric coordinates on D.
    Vector uvt = rayIntersectTriangle(xD_0, xD_1, xD_2, ray);
    Float u = uvt(0), v = uvt(1);
    // NOTE(review): the D vertices are wrapped in detach() here, presumably so that
    // only the (u, v) dependence is differentiated — confirm against detach()'s definition.
    x = (1.0 - u - v) * detach(xD_0) +
            u * detach(xD_1) +
            v * detach(xD_2);
}

// void d_velocity(const Vector &xS_0, Vector &d_xS_0, const Vector &xS_1, Vector &d_xS_1, const Vector &xS_2, Vector &d_xS_2, 
//                 const Float &uS, const Float &vS,
//                 const Vector &xB_0, Vector &d_xB_0, const Vector &xB_1, Vector &d_xB_1, const Vector &xB_2, Vector &d_xB_2, 
//                 const Float &uB, const Float &vB,
//                 const Vector &xD_0, Vector &d_xD_0, const Vector &xD_1, Vector &d_xD_1, const Vector &xD_2, Vector &d_xD_2, 
//                 Vector &d_x){
//     [[maybe_unused]] Vector x;
//     [[maybe_unused]] Float d_uS;
//     [[maybe_unused]] Float d_vS;
//     #if defined(ENZYME) && defined(PATH)
//         __enzyme_fwddiff((void *)velocity,
//                         enzyme_dup, &xS_0, &d_xS_0, 
//                         enzyme_dup, &xS_1, &d_xS_1, 
//                         enzyme_dup, &xS_2, &d_xS_2,
//                         enzyme_const, &uS,
//                         enzyme_const, &vS,
//                         enzyme_dup, &xB_0, &d_xB_0, 
//                         enzyme_dup, &xB_1, &d_xB_1, 
//                         enzyme_dup, &xB_2, &d_xB_2,
//                         enzyme_const, &uB, 
//                         enzyme_const, &vB,
//                         enzyme_dup, &xD_0, &d_xD_0, 
//                         enzyme_dup, &xD_1, &d_xD_1, 
//                         enzyme_dup, &xD_2, &d_xD_2, 
//                         enzyme_dup, &x, &d_x);
//     #endif
// }

void rayIntersectEdgeExt(const Scene& scene, const Intersection& origin, const Intersection &edge_ext, const Intersection &aux, Vector& x) {
    /* the desired dxdt will be stored in dxdt.der, since it has the correct size */
    // std::cout << "<-------------convUniDir------------>" << std::endl;
    const Shape *shapeB = scene.shape_list[edge_ext.shape_id];
    const Vector3i &indB = shapeB->getIndices(edge_ext.triangle_id);
    const Vector &xB_0 = shapeB->getVertex(indB[0]);
    const Vector &xB_1 = shapeB->getVertex(indB[1]);
    const Vector &xB_2 = shapeB->getVertex(indB[2]);
    const Float uB = edge_ext.barycentric[0],
                vB = edge_ext.barycentric[1];

    const Shape *shapeS = scene.shape_list[origin.shape_id];
    const auto &indS = shapeS->getIndices(origin.triangle_id);
    const Vector &xS_0 = shapeS->getVertex(indS[0]);
    const Vector &xS_1 = shapeS->getVertex(indS[1]);
    const Vector &xS_2 = shapeS->getVertex(indS[2]);
    const Float uS = origin.barycentric[0],
                vS = origin.barycentric[1];

    const Shape *shapeD = scene.shape_list[aux.shape_id];
    const auto &indD = shapeD->getIndices(aux.triangle_id);
    const Vector &xD_0 = shapeD->getVertex(indD[0]);
    const Vector &xD_1 = shapeD->getVertex(indD[1]);
    const Vector &xD_2 = shapeD->getVertex(indD[2]);

    velocity(xS_0, xS_1, xS_2, 
            uS, vS,
            xB_0, xB_1, xB_2,
            uB, vB,
            xD_0, xD_1, xD_2,
            x);
}

void getVertex(const Scene& scene,
               const int &shape_id, const int &vertex_id, Vector& vertex) {
    vertex = scene.shape_list[shape_id]->getVertex(vertex_id);
}

// Differentiates getVertex with Enzyme's __enzyme_autodiff.
//
// d_vertex is the shadow paired with the fetched vertex; scene derivatives go
// into d_scene. The call is compiled only when both ENZYME and PATH are
// defined; otherwise this function does nothing.
void d_getVertex(const Scene& scene, Scene &d_scene,
                int shape_id, int vertex_id, Vector& d_vertex) {
    [[maybe_unused]] Vector primalVertex;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff(
        (void *)getVertex,
        enzyme_dup,   &scene,        &d_scene,
        enzyme_const, &shape_id,
        enzyme_const, &vertex_id,
        enzyme_dup,   &primalVertex, &d_vertex);
#endif
}

// Differentiates rayIntersectEdgeExt with Enzyme's __enzyme_autodiff.
//
// The three intersection records are passed as constants; the scene and the
// output point are duplicated. d_x is the shadow paired with the output point,
// and scene derivatives go into the shadow scene selected by the calling
// OpenMP thread's id. The call is compiled only when both ENZYME and PATH are
// defined.
void d_rayIntersectEdgeExt(SceneAD& sceneAD, 
                            const Intersection& origin, const Intersection &edge_ext, const Intersection &aux, Vector& d_x) {
    auto &sceneShadow = sceneAD.gm.get(omp_get_thread_num());
    [[maybe_unused]] Vector primalPoint;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff(
        (void *)rayIntersectEdgeExt,
        enzyme_dup,   &sceneAD.val, &sceneShadow,
        enzyme_const, &origin,
        enzyme_const, &edge_ext,
        enzyme_const, &aux,
        enzyme_dup,   &primalPoint, &d_x);
#endif
}

// Back-propagates the shadows of the auxiliary points (`aux.p`) through
// rayIntersectEdgeExt for the vertices used by the (s, t) connection.
//
//   s == -1  only light sub-path vertices 0 .. t-1 are processed
//   t == -1  only camera sub-path vertices 1 .. s are processed
//   else     camera vertices 1 .. s, then light vertices 0 .. t-1, then the
//            connection cache stored on the light path
//
// A camera vertex i uses vertex i-1 as ray origin, a light vertex i uses
// vertex i+1, and the connection cache uses camera vertex s.
//
// Cache lists with fewer than NUM_AUX_SAMPLES entries are skipped instead of
// being indexed out of range; d_getWASPath and warpSurface treat such lists
// the same way.
void d_getWAS(SceneAD &sceneAD, 
                    LightPathAD &camera_path, int s, 
                    LightPathAD &light_path, int t)
{
    // Not used below; the lookup is kept from the original in case it has side effects.
    [[maybe_unused]] auto &d_scene = sceneAD.gm.get(omp_get_thread_num());

    // Differentiates every valid auxiliary sample of one cache list.
    const auto sweepCache = [&sceneAD](const Intersection &preV,
                                       const std::vector<WASCache> &caches,
                                       std::vector<WASCache> &d_caches) {
        if (caches.size() < NUM_AUX_SAMPLES || d_caches.size() < NUM_AUX_SAMPLES) return;
        for (int j = 0; j < NUM_AUX_SAMPLES; j++) {
            const WASCache &cache = caches[j];
            if (!cache.VALID) continue;
            d_rayIntersectEdgeExt(sceneAD, preV, cache.aux_half, cache.aux, d_caches[j].aux.p);
        }
    };

    // Camera sub-path; runs zero times when s == -1.
    for (int i = 1; i <= s; i++) {
        sweepCache(camera_path.val.vertices[i - 1],
                   camera_path.val.vertex_container[i].vertex_cache,
                   camera_path.der.vertex_container[i].vertex_cache);
    }

    // Light sub-path; runs zero times when t == -1.
    for (int i = 0; i < t; i++) {
        sweepCache(light_path.val.vertices[i + 1],
                   light_path.val.vertex_container[i].vertex_cache,
                   light_path.der.vertex_container[i].vertex_cache);
    }

    // The connection cache exists only when both sub-paths take part.
    if (s != -1 && t != -1) {
        sweepCache(camera_path.val.vertices[s],
                   light_path.val.connect_vertex_cache,
                   light_path.der.connect_vertex_cache);
    }
}

// Back-propagates the shadows of the auxiliary points (`aux.p`) through
// rayIntersectEdgeExt for every vertex of a whole sub-path.
//
// Camera sub-path: vertices 1 .. n-1, each using vertex i-1 as ray origin.
// Light sub-path:  vertices 0 .. n-2, each using vertex i+1 as ray origin.
// Vertices without `hasGradient`, or whose cache holds fewer than
// NUM_AUX_SAMPLES entries, are skipped.
//
// The vertex count is converted to a signed int before subtracting, so an
// empty light path yields an empty range instead of wrapping `size() - 1`
// around to a huge unsigned bound.
void d_getWASPath(SceneAD &sceneAD, 
                    LightPathAD &path)
{
    // Not used below; the lookup is kept from the original in case it has side effects.
    [[maybe_unused]] auto &d_scene = sceneAD.gm.get(omp_get_thread_num());

    const int numVertices = static_cast<int>(path.val.vertices.size());
    const bool isCamera = path.val.isCameraPath;
    const int first = isCamera ? 1 : 0;
    const int last = isCamera ? numVertices : numVertices - 1;  // exclusive
    const int preOffset = isCamera ? -1 : 1;                    // neighbour used as ray origin

    for (int i = first; i < last; i++) {
        const auto &container = path.val.vertex_container[i];
        if (!container.hasGradient) continue;
        if (container.vertex_cache.size() < NUM_AUX_SAMPLES) continue;

        const Intersection &preV = path.val.vertices[i + preOffset];
        for (int j = 0; j < NUM_AUX_SAMPLES; j++) { // for every aux_point
            const WASCache& cache = container.vertex_cache[j];
            if (!cache.VALID) continue;
            WASCache& d_cache = path.der.vertex_container[i].vertex_cache[j];
            d_rayIntersectEdgeExt(sceneAD, preV, cache.aux_half, cache.aux, d_cache.aux.p);
        }
    }
}

// INACTIVE_FN(harmonic_weight, harmonic_weight);

// Combines the auxiliary samples of one vertex into a weighted warp and its
// divergence-like scalar.
//
//   cache_list  auxiliary samples; only the first NUM_AUX_SAMPLES entries are read
//   Warp        output: sum of w * aux.p over valid, non-force_zero samples, divided by Z
//   divWarp     output: dwV / Z - wVd / Z^2 (see the accumulators below)
//
// If the list holds fewer than NUM_AUX_SAMPLES entries the function returns
// without writing either output.
//
// This is the primal that d_warpSurface hands to Enzyme, so the code is kept
// exactly as written.
void warpSurface(const std::vector<WASCache> &cache_list, 
                Vector &Warp, Float &divWarp){
    // Z: sum of the weights; dZ: sum of the weight gradients `dw`.
    double Z = 0.0;
    Vector dZ;
    dZ.setZero();
    if (cache_list.size() < NUM_AUX_SAMPLES) return;

    for (int i = 0; i < NUM_AUX_SAMPLES; i++) {
        if (!cache_list[i].VALID) continue;
        Z += cache_list[i].w;
        dZ += cache_list[i].dw;
    }

    Vector X_holder;
    X_holder.setZero();
    // dwV = sum of dw . aux.p;  wVd = sum of w * (dZ . aux.p)
    Float dwV = 0.0, wVd = 0.0;
    for (int i = 0; i < NUM_AUX_SAMPLES; i++) {
        const WASCache& cache = cache_list[i];
        if (!cache.VALID) continue;
        // force_zero samples are left out of the warp itself but still enter both
        // divergence accumulators.
        if (!cache.force_zero) {
            X_holder += cache.w * cache.aux.p;
        }
        dwV += cache.dw.dot(cache.aux.p);
        wVd += cache.w * dZ.dot(cache.aux.p);
    }

    // NOTE(review): Z stays 0 when no sample is VALID (or all weights are 0), so both
    // divisions below would yield non-finite values — confirm callers rule this out.
    Warp = X_holder / Z;
    divWarp = (dwV / Z - wVd / (Z * Z));
}

// Differentiates warpSurface with Enzyme's __enzyme_autodiff.
//
// d_warp and d_div_warp are the shadows paired with the two outputs;
// d_cache_list is the shadow paired with the sample list. The call is compiled
// only when both ENZYME and PATH are defined.
void d_warpSurface(const std::vector<WASCache> &cache_list, std::vector<WASCache> &d_cache_list,
                Vector &d_warp, Float &d_div_warp)
{
    [[maybe_unused]] Vector primalWarp;
    [[maybe_unused]] Float primalDivWarp;
#if defined(ENZYME) && defined(PATH)
    __enzyme_autodiff(
        (void *)warpSurface,
        enzyme_dup, &cache_list,    &d_cache_list,
        enzyme_dup, &primalWarp,    &d_warp,
        enzyme_dup, &primalDivWarp, &d_div_warp);
#endif
}


// Runs warpSurface for every vertex taking part in the (s, t) connection and
// stores the results in the vertex containers.
//
//   s == -1  only light vertices 0 .. t-1
//   t == -1  only camera vertices 1 .. s
//   else     camera vertices 1 .. s, light vertices 0 .. t-1, then the
//            connection cache stored on the light path
//
// With s == -1 (or t == -1) the matching loop below runs zero times, so both
// loops can be written unconditionally. `scene` is not used.
void evalWarp(const Scene &scene, 
                LightPath &camera_path, int s,
                LightPath &light_path, int t)
{
    for (int i = 1; i <= s; ++i) {
        auto &node = camera_path.vertex_container[i];
        warpSurface(node.vertex_cache, node.warped_X, node.div_warped_X);
    }

    for (int i = 0; i < t; ++i) {
        auto &node = light_path.vertex_container[i];
        warpSurface(node.vertex_cache, node.warped_X, node.div_warped_X);
    }

    const bool isConnection = (s != -1) && (t != -1);
    if (isConnection) {
        warpSurface(light_path.connect_vertex_cache,
                    light_path.connect_warped_X,
                    light_path.connect_div_warped_X);
    }
}

// Runs warpSurface for every vertex of a whole sub-path.
//
// Camera sub-path: vertices 1 .. n-1. Light sub-path: vertices 0 .. n-2.
// `scene` is not used.
//
// The vertex count is converted to a signed int before subtracting, so an
// empty light path yields an empty range instead of wrapping `size() - 1`
// around to a huge unsigned bound.
void evalWarpPath(const Scene &scene, 
            LightPath &path) {
    const int numVertices = static_cast<int>(path.vertices.size());
    const int first = path.isCameraPath ? 1 : 0;
    const int last = path.isCameraPath ? numVertices : numVertices - 1;  // exclusive

    for (int i = first; i < last; i++) {
        auto &container = path.vertex_container[i];
        warpSurface(container.vertex_cache, container.warped_X, container.div_warped_X);
    }
}

// Runs d_warpSurface for every vertex of a whole sub-path that is flagged
// `hasGradient`.
//
// Camera sub-path: vertices 1 .. n-1. Light sub-path: vertices 0 .. n-2.
// `scene` is not used.
//
// The vertex count is converted to a signed int before subtracting, so an
// empty light path yields an empty range instead of wrapping `size() - 1`
// around to a huge unsigned bound.
void d_evalWarpPath(const Scene &scene, 
            LightPathAD &path) {
    const int numVertices = static_cast<int>(path.val.vertices.size());
    const int first = path.val.isCameraPath ? 1 : 0;
    const int last = path.val.isCameraPath ? numVertices : numVertices - 1;  // exclusive

    for (int i = first; i < last; i++) {
        if (!path.val.vertex_container[i].hasGradient) continue;
        auto &d_container = path.der.vertex_container[i];
        d_warpSurface(path.val.vertex_container[i].vertex_cache, d_container.vertex_cache,
                      d_container.warped_X, d_container.div_warped_X);
    }
}

// Runs d_warpSurface for every vertex taking part in the (s, t) connection,
// visiting the same vertices in the same order as evalWarp.
//
//   s == -1  only light vertices 0 .. t-1
//   t == -1  only camera vertices 1 .. s
//   else     camera vertices 1 .. s, light vertices 0 .. t-1, then the
//            connection cache stored on the light path
//
// With s == -1 (or t == -1) the matching loop below runs zero times, so both
// loops can be written unconditionally. `scene` is not used.
void d_evalWarp(const Scene &scene,
                LightPathAD &camera_path, int s,
                LightPathAD &light_path, int t)
{
    for (int i = 1; i <= s; ++i) {
        auto &shadow = camera_path.der.vertex_container[i];
        d_warpSurface(camera_path.val.vertex_container[i].vertex_cache, shadow.vertex_cache,
                      shadow.warped_X, shadow.div_warped_X);
    }

    for (int i = 0; i < t; ++i) {
        auto &shadow = light_path.der.vertex_container[i];
        d_warpSurface(light_path.val.vertex_container[i].vertex_cache, shadow.vertex_cache,
                      shadow.warped_X, shadow.div_warped_X);
    }

    const bool isConnection = (s != -1) && (t != -1);
    if (isConnection) {
        d_warpSurface(light_path.val.connect_vertex_cache, light_path.der.connect_vertex_cache,
                      light_path.der.connect_warped_X, light_path.der.connect_div_warped_X);
    }
}

// Fills sum_warped_X / sum_div_warped_X with inclusive running sums of
// warped_X / div_warped_X along the path: entry i holds the sum over 0 .. i.
// The number of entries processed is path.vertices.size().
void evalPrefix(LightPath& path) {
    auto &nodes = path.vertex_container;
    const int count = static_cast<int>(path.vertices.size());

    for (int i = 0; i < count; ++i) {
        auto &cur = nodes[i];
        cur.sum_div_warped_X = cur.div_warped_X;
        cur.sum_warped_X = cur.warped_X;
        if (i == 0) continue;  // the first entry has nothing to carry over

        const auto &prev = nodes[i - 1];
        cur.sum_div_warped_X += prev.sum_div_warped_X;
        cur.sum_warped_X += prev.sum_warped_X;
    }
}

// Hand-written reverse pass of evalPrefix, walking the shadow containers from
// the last entry to the first.
//
// For each entry, the shadow of the running sums is added to the shadow of the
// per-vertex values and also handed on to the previous entry's running-sum
// shadow. The running-sum shadows themselves are not reset.
//
// NOTE(review): the entry count comes from path.val.vertex_container.size(),
// while evalPrefix iterates over path.vertices.size() — confirm the two always agree.
void d_evalPrefix(LightPathAD& path) {
    auto &shadows = path.der.vertex_container;
    const int count = path.val.vertex_container.size();

    for (int i = count; i-- > 0;) {
        auto &cur = shadows[i];

        // Running sum i depends directly on value i.
        cur.warped_X += cur.sum_warped_X;
        cur.div_warped_X += cur.sum_div_warped_X;

        if (i == 0) continue;

        // Running sum i also depends on running sum i-1.
        auto &prev = shadows[i - 1];
        prev.sum_warped_X += cur.sum_warped_X;
        prev.sum_div_warped_X += cur.sum_div_warped_X;
    }
}

// Returns L multiplied by each div_warped_X taking part in the (s, t)
// connection, summed term by term.
//
//   s == -1  light vertices 0 .. t-1
//   t == -1  camera vertices 1 .. s
//   else     light vertices 0 .. t-1, camera vertices 1 .. s, then the
//            connection term stored on the light path
//
// With s == -1 (or t == -1) the matching loop below runs zero times, so both
// loops can be written unconditionally. `scene` is not used.
Spectrum evalBoundary(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                      const Spectrum &L)
{
    Spectrum total = Spectrum::Zero();

    for (int i = 0; i < t; ++i) {
        total += L * light_path.vertex_container[i].div_warped_X;
    }

    for (int i = 1; i <= s; ++i) {
        total += L * camera_path.vertex_container[i].div_warped_X;
    }

    const bool isConnection = (s != -1) && (t != -1);
    if (isConnection) {
        total += L * light_path.connect_div_warped_X;
    }

    return total;
}

// Variant of evalBoundary that reads the precomputed running sums
// (sum_div_warped_X) instead of looping over the vertices.
//
//   s == -1  L * light[t].sum_div_warped_X
//   t == -1  L * camera[s].sum_div_warped_X
//   else     both of the above plus L * connect_div_warped_X
//
// NOTE(review): evalPrefix builds inclusive sums over 0 .. i, so light[t]
// includes entry t and camera[s] includes entry 0, whereas evalBoundary sums
// light 0 .. t-1 and camera 1 .. s — confirm this difference is intended.
// `scene` is not used.
Spectrum evalBoundary_prefix(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                      const Spectrum &L)
{
    Spectrum total = Spectrum::Zero();

    if (s == -1) {
        total += L * light_path.vertex_container[t].sum_div_warped_X;
        return total;
    }

    if (t == -1) {
        total += L * camera_path.vertex_container[s].sum_div_warped_X;
        return total;
    }

    total += L * light_path.vertex_container[t].sum_div_warped_X;
    total += L * camera_path.vertex_container[s].sum_div_warped_X;
    total += L * light_path.connect_div_warped_X;
    return total;
}

void __evalBoundary(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                    const Spectrum &L, Spectrum &value)
{
    value = evalBoundary(scene, camera_path, s, light_path, t, L);
}

void __evalBoundary_prefix(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                    const Spectrum &L, Spectrum &value)
{
    value = evalBoundary_prefix(scene, camera_path, s, light_path, t, L);
}

// Hand-written reverse pass of evalBoundary with respect to the div_warped_X
// entries.
//
// Every div_warped_X that evalBoundary reads is multiplied by L, so each one
// receives the same increment: the dot product of d_value and L over the three
// spectrum channels.
//
//   s == -1  light vertices 0 .. t-1
//   t == -1  camera vertices 1 .. s
//   else     light vertices 0 .. t-1, camera vertices 1 .. s, and the
//            connection term stored on the light path
//
// With s == -1 (or t == -1) the matching loop below runs zero times, so both
// loops can be written unconditionally. `scene` is not used.
void d_evalBoundary(const Scene &scene,
                    LightPathAD &camera_path, int s, LightPathAD &light_path, int t, // only use Warp and divWarp
                    const Spectrum &L,
                    Spectrum d_value)
{
    // Commented-out Enzyme call on __evalBoundary, kept for reference:
    // [[maybe_unused]] Spectrum value;
    // __enzyme_autodiff((void *)__evalBoundary,
    //                   enzyme_const, &scene,
    //                   enzyme_dup, &camera_path.val, &camera_path.der,
    //                   enzyme_const, s,
    //                   enzyme_dup, &light_path.val, &light_path.der,
    //                   enzyme_const, t,
    //                   enzyme_const, L,
    //                   enzyme_dup, &value, &d_value);

    const auto increment = d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];

    for (int i = 0; i < t; ++i) {
        light_path.der.vertex_container[i].div_warped_X += increment;
    }

    for (int i = 1; i <= s; ++i) {
        camera_path.der.vertex_container[i].div_warped_X += increment;
    }

    const bool isConnection = (s != -1) && (t != -1);
    if (isConnection) {
        light_path.der.connect_div_warped_X += increment;
    }
}

// Differentiates __evalBoundary_prefix with Enzyme's __enzyme_autodiff.
//
// The two sub-paths and the output spectrum are duplicated (primal + shadow);
// the scene, s, t and L are constants. d_value is the shadow paired with the
// output spectrum.
//
// L is passed by address: __evalBoundary_prefix takes it as `const Spectrum &`,
// and every other reference argument handed to Enzyme in this file is passed
// as a pointer. The original call passed the Spectrum object itself.
void d_evalBoundary_prefix(const Scene &scene,
                    LightPathAD &camera_path, int s, LightPathAD &light_path, int t, // only use Warp and divWarp
                    const Spectrum &L,
                    Spectrum d_value)
{
    [[maybe_unused]] Spectrum value;
    __enzyme_autodiff((void *)__evalBoundary_prefix,
                      enzyme_const, &scene,
                      enzyme_dup, &camera_path.val, &camera_path.der,
                      enzyme_const, s,
                      enzyme_dup, &light_path.val, &light_path.der,
                      enzyme_const, t,
                      enzyme_const, &L,
                      enzyme_dup, &value, &d_value);
    // Commented-out hand-written version, kept for reference:
    // if ( s == -1 ) {
    //     light_path.der.vertex_container[t].sum_div_warped_X += d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];
    // } else if ( t == -1 ) {
    //     camera_path.der.vertex_container[s].sum_div_warped_X += d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];
    // } else {
    //     light_path.der.vertex_container[t].sum_div_warped_X += d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];
    //     camera_path.der.vertex_container[s].sum_div_warped_X += d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];
    //     light_path.der.connect_div_warped_X += d_value[0] * L[0] + d_value[1] * L[1] + d_value[2] * L[2];
    // }
}

NAMESPACE_END(algorithm1_bdptwas)