#include "algorithm1.h"
#include <cmath>
#include <render/scene.h>
#include <core/math_func.h>
#include <bsdf/roughconductor.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughdielectric.h>
#include <mala_utils.h>

NAMESPACE_BEGIN(algorithm1_MALA_indirect)    


    void sampleEdgeRay(const Scene &scene, const Array3 rnd,
                       const DiscreteDistribution &edge_dist,
                       const std::vector<Vector2i> &edge_indices,
                       EdgeRaySamplingRecord &eRec, 
                       EdgePrimarySampleRecord *ePSRec = nullptr)
    {
        /* Sample a point on the boundary */
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec, ePSRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample edge ray */
        if (edge.f1 < 0) // Case 1: boundary edge
        {
            eRec.dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            eRec.pdf /= 4. * M_PI;
        }
        else // Case 2: non-boundary edge
        {
            Vector n0 = shape->getGeoNormal(edge.f0);
            Vector n1 = shape->getGeoNormal(edge.f1);
            Float pdf1;
            eRec.dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
            eRec.pdf *= pdf1;

            Float dotn0 = eRec.dir.dot(n0), dotn1 = eRec.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5f)
            {
                std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
                PSDR_INFO("rnd: {}, {}, {}", rnd[0], rnd[1], rnd[2]);
                PSDR_INFO("f0: {}, f1: {}", edge.f0, edge.f1);
                PSDR_INFO("n0: {}, {}, {}", n0[0], n0[1], n0[2]);
                PSDR_INFO("n1: {}, {}, {}", n1[0], n1[1], n1[2]);
                PSDR_INFO("dir: {}, {}, {}", eRec.dir[0], eRec.dir[1], eRec.dir[2]);
                eRec.shape_id = -1;
            }
        }
        if (ePSRec != nullptr) {
            ePSRec->pdf = eRec.pdf;
            ePSRec->edge_id = eRec.edge_id;
            ePSRec->shape_id = eRec.shape_id;
        }
    }

    void edgePtDirToIntersection(const Vector &edge_point, const Vector &dir, const Scene &scene, 
                    const EdgePrimarySampleRecord &ePSRec, 
                    const Array2i &indicesD, const Array2i &indicesS,
                    Float &contrib, Intersection &itsD, Intersection &itsS){
        const Shape *shape = scene.shape_list[ePSRec.shape_id];
        const Edge &edge = shape->edges[ePSRec.edge_id];
        
        Ray edgeRay(edge_point, dir);

        const Shape* shapeD = scene.getShape(indicesD[0]);
        const Shape* shapeS = scene.getShape(indicesS[0]);
        

        Vector3i idxD = shapeD->getIndices(indicesD[1]);
        Vector v0D = detach(shapeD->getVertex(idxD[0]));
        Vector v1D = detach(shapeD->getVertex(idxD[1]));
        Vector v2D = detach(shapeD->getVertex(idxD[2]));

        Array uvtD = rayIntersectTriangle(v0D, v1D, v2D, edgeRay.flipped());
        Vector detector_point = edge_point + uvtD[2] * (-dir);
        Vector geo_nD = detach(shapeD->getFaceNormal(indicesD[1]));
        Vector sh_nD = detach(shapeD->getShadingNormal(indicesD[1], Vector2(uvtD[0], uvtD[1])));
        itsD.p = detector_point;
        itsD.geoFrame = Frame(geo_nD);
        itsD.shFrame = Frame(sh_nD);
        itsD.wi = itsD.toLocal(dir);
        itsD.uv = Vector2(uvtD[0], uvtD[1]);

        Vector3i idxS = shapeS->getIndices(indicesS[1]);
        Vector v0S = detach(shapeS->getVertex(idxS[0]));
        Vector v1S = detach(shapeS->getVertex(idxS[1]));
        Vector v2S = detach(shapeS->getVertex(idxS[2]));

        Array uvtS = rayIntersectTriangle(v0S, v1S, v2S, edgeRay);
        Vector emitter_point = edge_point + uvtS[2] * (dir);
        Vector geo_nS = detach(shapeS->getFaceNormal(indicesS[1]));
        Vector sh_nS = detach(shapeS->getShadingNormal(indicesS[1], Vector2(uvtS[0], uvtS[1])));
        itsS.p = detector_point;
        itsS.geoFrame = Frame(geo_nS);
        itsS.shFrame = Frame(sh_nS);
        itsS.wi = itsS.toLocal(-dir);
        itsS.uv = Vector2(uvtS[0], uvtS[1]);

        // Array uvtS = rayIntersectTriangle(dPSRec.v0, dPSRec.v1, dPSRec.v2, edgeRay);
        // Vector emitter_point = edge_point + uvtS[2] * dir;

        // Spectrum value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, -dir);
        Float J = dlD_dlB(emitter_point,
                          edge_point, (ePSRec.v0 - ePSRec.v1).normalized(),
                          detector_point, geo_nD) *
                  dA_dw(edge_point, emitter_point, geo_nS);

        // Float sensorVal = scene.camera.eval(pixel_idx[0], pixel_idx[1], its_Cam.p);
        
        // Spectrum bsdf_val = its_Cam.evalBSDF(-dir);
        Float baseValue = J * geometric(detector_point, geo_nD, emitter_point, geo_nS);
        contrib = baseValue / ePSRec.pdf; // eRec.pdf, bsdf_val
    }

    /*
     * Reverse-mode derivative of edgePtDirToIntersection via Enzyme.
     *
     * Adjoints accumulate into d_edge_point, d_dir and d_scene. The output seeds
     * d_baseValue, d_itsD and d_itsS are taken by value, so the caller's copies are not
     * modified. ePSRec, indicesD and indicesS are inactive (enzyme_const).
     * The enzyme_* argument list must stay in the exact order of the primal's parameters.
     * When ENZYME and ENZYME_BOUNDARY_DIRECT are not both defined, this is a no-op.
     */
    void d_edgePtDirToIntersection(const Vector &edge_point, const Vector &dir, Vector &d_edge_point, Vector &d_dir, const Scene &scene, Scene &d_scene, 
                    const EdgePrimarySampleRecord &ePSRec, 
                    const Array2i &indicesD, const Array2i &indicesS,
                    Float d_baseValue, Intersection d_itsD, Intersection d_itsS) {
        // Scratch primal outputs; only their shadows (the seeds above) matter.
        [[maybe_unused]] Float baseValue;
        [[maybe_unused]] Intersection itsD, itsS;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)edgePtDirToIntersection,
                            enzyme_dup, &edge_point, &d_edge_point,
                            enzyme_dup, &dir, &d_dir,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &ePSRec,
                            enzyme_const, &indicesD,
                            enzyme_const, &indicesS,
                            enzyme_dup, &baseValue, &d_baseValue,
                            enzyme_dup, &itsD, &d_itsD,
                            enzyme_dup, &itsS, &d_itsS);
#endif
    }

    /*
     * Area-to-solid-angle factor at point p (surface normal n) as seen from ref:
     *   |p - ref|^2 / |n . normalize(ref - p)|
     * No guard against a grazing normal: the result is unbounded when the dot product is 0.
     */
    Float dA_dw_att(const Vector &ref,
                const Vector &p, const Vector &n)
    {
        const Vector offset = p - ref;
        // std::abs, not unqualified abs: the latter may bind to C's `int abs(int)` and
        // truncate the cosine (|cos| < 1 -> 0), turning this into a division by zero.
        return offset.squaredNorm() / std::abs(n.dot(-offset.normalized()));
    }

    /*
     * Length Jacobian relating the edge parameterisation at xB to the far endpoint xS,
     * seen from xD:
     *   |xS - xD| / |xB - xD| * sin(v, d) / sin(d, project(v, d, n)),  d = normalize(xS - xD).
     * Returns 0 unless both sines exceed Epsilon (degenerate / parallel configurations).
     */
    Float dlS_dlB_att(const Vector &xD,
                    const Vector &xB, const Vector &v,
                    const Vector &xS, const Vector &n)
    {
        const Vector to_far = xS - xD;
        const Vector to_edge = xB - xD;
        const Vector ray_dir = to_far.normalized();
        const Vector v_projected = project(v, ray_dir, n);
        const Float sin_edge = v.cross(ray_dir).norm();
        const Float sin_far = ray_dir.cross(v_projected).norm();

        // Written as a positive test so NaN sines also fall through to 0.
        if (sin_edge > Epsilon && sin_far > Epsilon)
            return to_far.norm() / to_edge.norm() * sin_edge / sin_far;
        return 0.;
    }

    /*
     * Detector-side counterpart of dlS_dlB_att: same formula with the endpoint roles
     * swapped. xS goes in the first (viewpoint) slot and xD in the far-endpoint slot.
     */
    Float dlD_dlB_att(const Vector &xS,
                                        const Vector &xB, const Vector &v,
                                        const Vector &xD, const Vector &n)
    {
        const Vector &viewpoint = xS;
        const Vector &far_endpoint = xD;
        return dlS_dlB_att(viewpoint, xB, v, far_endpoint, n);
    }

    /*
     * Maps primary sample `rnd` to an edge ray and evaluates its boundary contribution.
     *
     *   1. rnd[0] is remapped with ePSRec.offset / ePSRec.scale and used to interpolate a
     *      point on the segment ePSRec.v0 -> ePSRec.v1.
     *   2. rnd[1], rnd[2] select a direction: uniform sphere for a boundary edge (f1 < 0),
     *      squareToEdgeRayDirection relative to both face normals otherwise.
     *   3. The flipped ray is intersected with detector triangle indicesD, and the forward
     *      ray with emitter triangle indicesS. itsD / itsS receive position, frames, wi
     *      and uv.
     *   4. contrib = J * geometric(...) / pdf, where
     *      pdf = scale / |v1 - v0| * (direction pdf).
     *      The pdf is recomputed here rather than read from ePSRec.pdf (see the disabled
     *      assert).
     *
     * indicesD / indicesS are (shape index for Scene::getShape, face index).
     * Differentiated w.r.t. rnd and scene by d_getEdgeIntersectionDir (Enzyme).
     */
    void getEdgeIntersectionDir(const Vector &rnd, const Scene &scene, 
                    const EdgePrimarySampleRecord &ePSRec, 
                    const Array2i &indicesD, const Array2i &indicesS,
                    Float &contrib, Intersection &itsD, Intersection &itsS) {
        Float pdf = 1.0;
        Vector _rnd(rnd);
        // Undo the offset/scale remap of rnd[0] (presumably a restriction of the primary
        // sample to this edge's interval -- TODO confirm). The density picks up `scale`.
        _rnd[0] = (_rnd[0] - ePSRec.offset) / ePSRec.scale;
        pdf = ePSRec.scale;
        Vector edge_point = ePSRec.v0 + _rnd[0] * (ePSRec.v1 - ePSRec.v0);
        // Convert to a density per unit length along the edge.
        pdf /= (ePSRec.v1 - ePSRec.v0).norm();
        const Shape *shape = scene.shape_list[ePSRec.shape_id];
        const Edge &edge = shape->edges[ePSRec.edge_id];
        
        Vector dir; // = squareToUniformSphere(Vector2{_rnd[1], _rnd[2]});
        assert(edge.f0 >= 0);
        // NOTE: the branches read rnd[1], rnd[2]. These equal _rnd[1], _rnd[2], since only
        // component 0 was remapped.
        if (edge.f1 < 0) // Case 1: boundary edge
        {
            dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            pdf /= 4. * M_PI;
        }
        else // Case 2: non-boundary edge
        {
            Vector n0 = shape->getGeoNormal(edge.f0);
            Vector n1 = shape->getGeoNormal(edge.f1);
            Float pdf1;
            dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
            pdf *= pdf1;
        }
        
        Ray edgeRay(edge_point, dir);

        const Shape* shapeD = scene.getShape(indicesD[0]);
        const Shape* shapeS = scene.getShape(indicesS[0]);
        

        // Detector side: intersect the flipped ray; the hit lies at edge_point - t * dir.
        Vector3i idxD = shapeD->getIndices(indicesD[1]);
        Vector v0D = detach(shapeD->getVertex(idxD[0]));
        Vector v1D = detach(shapeD->getVertex(idxD[1]));
        Vector v2D = detach(shapeD->getVertex(idxD[2]));

        Array uvtD = rayIntersectTriangle(v0D, v1D, v2D, edgeRay.flipped());
        Vector detector_point = edge_point + uvtD[2] * (-dir);
        Vector geo_nD = detach(shapeD->getFaceNormal(indicesD[1]));
        Vector sh_nD = detach(shapeD->getShadingNormal(indicesD[1], Vector2(uvtD[0], uvtD[1])));
        itsD.p = detector_point;
        itsD.geoFrame = Frame(geo_nD);
        itsD.shFrame = Frame(sh_nD);
        itsD.wi = itsD.toLocal(dir);
        itsD.uv = Vector2(uvtD[0], uvtD[1]);

        // Emitter side: intersect the forward ray; the hit lies at edge_point + t * dir.
        Vector3i idxS = shapeS->getIndices(indicesS[1]);
        Vector v0S = detach(shapeS->getVertex(idxS[0]));
        Vector v1S = detach(shapeS->getVertex(idxS[1]));
        Vector v2S = detach(shapeS->getVertex(idxS[2]));

        Array uvtS = rayIntersectTriangle(v0S, v1S, v2S, edgeRay);
        Vector emitter_point = edge_point + uvtS[2] * (dir);
        Vector geo_nS = detach(shapeS->getFaceNormal(indicesS[1]));
        Vector sh_nS = detach(shapeS->getShadingNormal(indicesS[1], Vector2(uvtS[0], uvtS[1])));
        itsS.p = emitter_point;
        itsS.geoFrame = Frame(geo_nS);
        itsS.shFrame = Frame(sh_nS);
        itsS.wi = itsS.toLocal(-dir);
        itsS.uv = Vector2(uvtS[0], uvtS[1]);

        // Array uvtS = rayIntersectTriangle(dPSRec.v0, dPSRec.v1, dPSRec.v2, edgeRay);
        // Vector emitter_point = edge_point + uvtS[2] * dir;

        // Spectrum value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, -dir);
        // Jacobian built from the local *_att helpers (not dlD_dlB / dA_dw).
        Float J = dlD_dlB_att(emitter_point,
                          edge_point, (ePSRec.v0 - ePSRec.v1).normalized(),
                          detector_point, geo_nD) *
                  dA_dw_att(edge_point, emitter_point, geo_nS);

        // Float sensorVal = scene.camera.eval(pixel_idx[0], pixel_idx[1], its_Cam.p);
        
        // Spectrum bsdf_val = its_Cam.evalBSDF(-dir);
        Float baseValue = J * geometric(detector_point, geo_nD, emitter_point, geo_nS);
        // assert(abs(pdf - ePSRec.pdf) < 1e-6);
        contrib = baseValue / pdf; // eRec.pdf, bsdf_val
    }

    /*
     * Reverse-mode derivative of getEdgeIntersectionDir via Enzyme.
     *
     * Adjoints accumulate into d_rnd and d_scene. The output seeds d_baseValue, d_itsD and
     * d_itsS are taken by value, so the caller's copies are not modified. ePSRec,
     * indicesD and indicesS are inactive (enzyme_const).
     * The enzyme_* argument list must mirror the primal's parameter order exactly.
     * When ENZYME and ENZYME_BOUNDARY_DIRECT are not both defined, this is a no-op.
     */
    void d_getEdgeIntersectionDir(const Vector &rnd, Vector &d_rnd, const Scene &scene, Scene &d_scene,
                    const EdgePrimarySampleRecord &ePSRec, 
                    const Array2i &indicesD, const Array2i &indicesS,
                    Float d_baseValue, Intersection d_itsD, Intersection d_itsS) {
        // Scratch primal outputs; only their shadows (the seeds above) matter.
        [[maybe_unused]] Float baseValue;
        [[maybe_unused]] Intersection itsD, itsS;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)getEdgeIntersectionDir,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &ePSRec,
                            enzyme_const, &indicesD,
                            enzyme_const, &indicesS,
                            enzyme_dup, &baseValue, &d_baseValue,
                            enzyme_dup, &itsD, &d_itsD,
                            enzyme_dup, &itsS, &d_itsS);
#endif
    }

    // void getEdgeIntersection(const Vector &rnd, const Scene &scene, 
    //                 const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
    //                 const Array2i &indices,
    //                 Spectrum &contrib) {

    //     /* Sample point on emitters */
    //     // DirectSamplingRecord dRec(eRec.ref);
    //     Vector light_norm = dPSRec.n;
    //     Vector _rnd(rnd);
    //     _rnd[0] = (_rnd[0] - ePSRec.offset) / ePSRec.scale;

    //     Vector edge_point = ePSRec.v0 + _rnd[0] * (ePSRec.v1 - ePSRec.v0);

    //     Vector emitter_point;
    //     if (!dPSRec.continuous){ // sample_triangle -> sample barycentric
    //         _rnd[1] = (_rnd[1] - dPSRec.offset) / dPSRec.scale;
    //         Float a = std::sqrt(_rnd[1]);
    //         Float u = 1.0f - a, v = a * _rnd[2];
    //         emitter_point = dPSRec.v0 + (dPSRec.v1 - dPSRec.v0) * u + (dPSRec.v2 - dPSRec.v0) * v;
    //     } else { // sample UV -> find barycentric -> find position
    //         Vector query_point = Vector(_rnd[1], _rnd[2], 1.0);
    //         Ray query_ray = Ray(query_point, Vector(0.0, 0.0, -1.0));
    //         Array uvt = rayIntersectTriangle(dPSRec.uv_mesh0, dPSRec.uv_mesh1, dPSRec.uv_mesh2, query_ray); // find barycentric coord
    //         emitter_point = dPSRec.v0 + (dPSRec.v1 - dPSRec.v0) * uvt[0] + (dPSRec.v2 - dPSRec.v0) * uvt[1]; // interpolate position
    //     }

    //     Vector dir = edge_point - emitter_point;
    //     Float dist = dir.norm();
    //     dir /= dist;
    //     // value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, dir);
        
    //     Ray ray(edge_point, dir);

    //     // Intersection its;
    //     // scene.rayIntersect(ray, true, its);
    //     Shape* shape_cam = scene.getShape(indices[0]);
    //     // scene.rayIntersect(ray, true, its);
    //     // shape_cam->rayIntersect(its_Cam.indices[1], ray, holder);
    //     Vector3i idx_cam = shape_cam->getIndices(indices[1]);
    //     Vector v0_cam = detach(shape_cam->getVertex(idx_cam[0]));
    //     Vector v1_cam = detach(shape_cam->getVertex(idx_cam[1]));
    //     Vector v2_cam = detach(shape_cam->getVertex(idx_cam[2]));

    //     Array uvt = rayIntersectTriangle(v0_cam, v1_cam, v2_cam, ray);
    //     Vector p = edge_point + uvt[2] * dir;
    //     Vector geo_n = detach(shape_cam->getFaceNormal(indices[1]));

    //     Float G = std::abs(dPSRec.n.dot(dir)) / (dist * dist);

    //     Spectrum value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, dir);
    //     value = value * G / dPSRec.pdf; // compensate for sampling on emitter surface
    //     Float J = dlD_dlB(emitter_point,
    //                       edge_point, (ePSRec.v0 - ePSRec.v1).normalized(),
    //                       p, geo_n) *
    //               dA_dw(edge_point, emitter_point, light_norm);

    //     // // Float sensorVal = scene.camera.eval(pixel_idx[0], pixel_idx[1], its_Cam.p);
        
    //     // // Spectrum bsdf_val = its_Cam.evalBSDF(-dir);
    //     // int bsdf_id = its.ptr_shape->bsdf_id;
    //     Float baseValue = J * geometric(p, geo_n, emitter_point, light_norm);
    //     contrib = baseValue * value / ePSRec.pdf; // eRec.pdf, bsdf_val
    // }
    

//     void d_getEdgeIntersection(const Vector &rnd, Vector &d_rnd, const Scene &scene, Scene &d_scene,
//                     const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
//                     const Array2i &indices,
//                     Spectrum &d_contrib) {
//         [[maybe_unused]] Spectrum contrib;
// #if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
//         __enzyme_autodiff((void *)getEdgeIntersection,
//                             enzyme_dup, &rnd, &d_rnd,
//                             enzyme_dup, &scene, &d_scene,
//                             enzyme_const, &ePSRec,
//                             enzyme_const, &dPSRec,
//                             enzyme_const, &indices,
//                             enzyme_dup, &contrib, &d_contrib);
// #endif
//     }

    void evalIntersection(const Scene &scene, const BSDF &bsdf, const Vector &rnd, const Array2i &next_indices, const EBSDFMode &mode,
                            const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = bsdf.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        mode);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    /*
     * Reverse-mode derivative of evalIntersection via Enzyme.
     *
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its (the shadow of
     * the output intersection). bsdf, next_indices and mode are inactive (enzyme_const).
     * The enzyme_* argument list must mirror evalIntersection's parameter order exactly.
     * When ENZYME and ENZYME_BOUNDARY_DIRECT are not both defined, this is a no-op.
     */
    void d_evalIntersection(const Scene &scene, Scene &d_scene,
                            const BSDF &bsdf, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const EBSDFMode &mode,
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output. NOTE(review): it is default-constructed, so any fields
        // evalIntersection does not assign keep their defaults in this primal run.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersection,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &bsdf,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_const, &mode,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    // I wish we didn't need a separate copy of this function for every concrete BSDF type and mode, but it's the only way to get Enzyme to work with the codebase
    void evalIntersectionRDRad(const Scene &scene, const RoughDielectricBSDF &rd, const Vector &rnd, const Array2i next_indices, 
                            const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rd.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::ERadiance);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    void evalIntersectionRDIC(const Scene &scene, const RoughDielectricBSDF &rd, const Vector &rnd, const Array2i next_indices, 
                            const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rd.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    /*
     * Reverse-mode derivative of evalIntersectionRDRad via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. rd and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionRDRad(const Scene &scene, Scene &d_scene,
                            const RoughDielectricBSDF &rd, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRDRad,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rd,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    /*
     * Reverse-mode derivative of evalIntersectionRDIC via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. rd and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionRDIC(const Scene &scene, Scene &d_scene,
                            const RoughDielectricBSDF &rd, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRDIC,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rd,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    void evalIntersectionRCRad(const Scene &scene, const RoughConductorBSDF &rc, const Vector &rnd, 
                            const Array2i next_indices, const Intersection &cur, Intersection &next) {
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rc.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getGeoNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    void evalIntersectionRCIC(const Scene &scene, const RoughConductorBSDF &rc, const Vector &rnd, 
                            const Array2i next_indices, const Intersection &cur, Intersection &next) {
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rc.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getGeoNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    /*
     * Reverse-mode derivative of evalIntersectionRCRad via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. rc and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionRCRad(const Scene &scene, Scene &d_scene,
                            const RoughConductorBSDF &rc,
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection val;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRCRad,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rc,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &val, &d_its);
#endif
    }

    /*
     * Reverse-mode derivative of evalIntersectionRCIC via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. rc and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionRCIC(const Scene &scene, Scene &d_scene,
                            const RoughConductorBSDF &rc,
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection val;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRCIC,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rc,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &val, &d_its);
#endif
    }

    void evalIntersectionDiffuseRad(const Scene &scene, const DiffuseBSDF &diffuse, const Vector &rnd, const Array2i next_indices, const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        next.value = diffuse.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::ERadiance);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getGeoNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = detach(Vector2(uvt[0], uvt[1]));
        return;
    }

    void evalIntersectionDiffuseIC(const Scene &scene, const DiffuseBSDF &diffuse, const Vector &rnd, const Array2i next_indices, const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        next.value = diffuse.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getGeoNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = detach(Vector2(uvt[0], uvt[1]));
        return;
    }

    /*
     * Reverse-mode derivative of evalIntersectionDiffuseRad via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. diffuse and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionDiffuseRad(const Scene &scene, Scene &d_scene,
                            const DiffuseBSDF &diffuse, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionDiffuseRad,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &diffuse,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    /*
     * Reverse-mode derivative of evalIntersectionDiffuseIC via Enzyme.
     * Adjoints accumulate into d_scene, d_rnd and d_cur, seeded by d_its. diffuse and
     * next_indices are inactive (enzyme_const). This is a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are both defined.
     */
    void d_evalIntersectionDiffuseIC(const Scene &scene, Scene &d_scene,
                            const DiffuseBSDF &diffuse, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Scratch primal output; only its shadow d_its matters.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionDiffuseIC,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &diffuse,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    void evalCamera(const Intersection &its, const Scene &scene, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        Vector cam_dir = (scene.camera.cpos - its.p).normalized();
        int bsdf_id = its.ptr_shape->bsdf_id;
        Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    void evalCamera_echo(const Intersection &its, const Scene &scene, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        // if (its.pixel_idx[0] < 0 || its.pixel_idx[0] > scene.camera.getNumPixels()){
        //     contrib = Spectrum::Zero();
        //     return;
        // }
        Vector cam_dir = (scene.camera.cpos - its.p).normalized();
        int bsdf_id = its.ptr_shape->bsdf_id;
        // Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        Spectrum bsdf_val = its.evalBSDF(its.toLocal(cam_dir),
                                     EBSDFMode::EImportanceWithCorrection);
        PSDR_INFO("bsdf_val: {}, {}, {}", bsdf_val[0], bsdf_val[1], bsdf_val[2]);
        PSDR_INFO("cam_dir: {}, {}, {}", its.toLocal(cam_dir)[0], its.toLocal(cam_dir)[1], its.toLocal(cam_dir)[2]);
        contrib = weight * bsdf_val;
    }

    /*
     * Enzyme reverse-mode derivative of evalCamera(). Every argument is
     * duplicated (enzyme_dup), so adjoints from d_contrib are accumulated into
     * d_its, d_scene and d_weight. Compiles to a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined.
     */
    void d_evalCamera(const Intersection &its, Intersection &d_its, const Scene &scene, Scene &d_scene, const Spectrum &weight, Spectrum &d_weight,
                        Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalCamera,
                          enzyme_dup, &its,            &d_its,
                          enzyme_dup, &scene,          &d_scene,
                          enzyme_dup, &weight,         &d_weight,
                          enzyme_dup, &primal_contrib, &d_contrib);
#endif
    }

    void evalCameraRC(const Intersection &its, const Vector &cam_p, const RoughConductorBSDF &rc, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        Vector cam_dir = (cam_p - its.p).normalized();
        Spectrum bsdf_val = rc.eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    /*
     * Enzyme reverse-mode derivative of evalCameraRC(). its and weight are
     * duplicated, so adjoints from d_contrib accumulate into d_its and d_weight.
     * cam_p and rc are constants. Compiles to a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined.
     */
    void d_evalCameraRC(const Intersection &its, Intersection &d_its, const Vector &cam_p, const RoughConductorBSDF &rc, const Spectrum &weight, Spectrum &d_weight,
                        Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalCameraRC,
                          enzyme_dup,   &its,            &d_its,
                          enzyme_const, &cam_p,
                          enzyme_const, &rc,
                          enzyme_dup,   &weight,         &d_weight,
                          enzyme_dup,   &primal_contrib, &d_contrib);
#endif
    }

    void evalCameraRD(const Intersection &its, const Vector &cam_p, const RoughDielectricBSDF &rd, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        Vector cam_dir = (cam_p - its.p).normalized();
        Spectrum bsdf_val = rd.eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    /*
     * Enzyme reverse-mode derivative of evalCameraRD(). its and weight are
     * duplicated, so adjoints from d_contrib accumulate into d_its and d_weight.
     * cam_p and rd are constants. Compiles to a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined.
     */
    void d_evalCameraRD(const Intersection &its, Intersection &d_its, const Vector &cam_p, const RoughDielectricBSDF &rd, const Spectrum &weight, Spectrum &d_weight,
                        Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalCameraRD,
                          enzyme_dup,   &its,            &d_its,
                          enzyme_const, &cam_p,
                          enzyme_const, &rd,
                          enzyme_dup,   &weight,         &d_weight,
                          enzyme_dup,   &primal_contrib, &d_contrib);
#endif
    }

    void evalCameraDi(const Intersection &its, const Vector &cam_p, const DiffuseBSDF &di, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        Vector cam_dir = (cam_p - its.p).normalized();
        Spectrum bsdf_val = di.eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    /*
     * Enzyme reverse-mode derivative of evalCameraDi(). its and weight are
     * duplicated, so adjoints from d_contrib accumulate into d_its and d_weight.
     * cam_p and di are constants. Compiles to a no-op unless ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined.
     */
    void d_evalCameraDi(const Intersection &its, Intersection &d_its, const Vector &cam_p, const DiffuseBSDF &di, const Spectrum &weight, Spectrum &d_weight,
                        Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalCameraDi,
                          enzyme_dup,   &its,            &d_its,
                          enzyme_const, &cam_p,
                          enzyme_const, &di,
                          enzyme_dup,   &weight,         &d_weight,
                          enzyme_dup,   &primal_contrib, &d_contrib);
#endif
    }

    void evalLight(const Vector &its_p, const int &light_id, const Vector &its_n, const Vector &its_prev_p, const Scene &scene, const Spectrum &weight, Spectrum &contrib) {
        // const Shape *ptr_emitter = scene.shape_list[its.shape_id];
        const Emitter *emitter = scene.emitter_list[light_id];
        Vector dir = (its_prev_p - its_p).normalized();
        // Spectrum light_contrib = its.Le(-dir);
        Spectrum light_contrib = emitter->eval(its_n, dir);
        if (!light_contrib.isZero(Epsilon))
        {
            contrib = weight * light_contrib;
            // rRec.values[depth + 1] += throughput * light_contrib;
        }
    }

    void evalLightDir(const Intersection &its, const Scene &scene, const Spectrum &weight, Spectrum &contrib) {
        const Shape *ptr_emitter = scene.shape_list[its.shape_id];
        const Emitter *emitter = scene.emitter_list[ptr_emitter->light_id];
        Vector dir = its.toWorld(its.wi);
        Spectrum light_contrib =  emitter->eval(its.geoFrame.n, dir);
        if (!light_contrib.isZero(Epsilon))
        {
            contrib = weight * light_contrib;
            // rRec.values[depth + 1] += throughput * light_contrib;
        }
    }

    /*
     * Enzyme reverse-mode derivative of evalLight(). its_p, its_prev_p, scene and
     * weight are duplicated, so adjoints from d_contrib accumulate into
     * d_its_p, d_its_prev_p, d_scene and d_weight. light_id and its_n are
     * constants. Compiles to a no-op unless ENZYME and ENZYME_BOUNDARY_DIRECT
     * are defined.
     */
    void d_evalLight(const Vector &its_p, Vector &d_its_p, const int &light_id, const Vector &its_n, const Vector &its_prev_p, Vector &d_its_prev_p, 
                    const Scene &scene, Scene &d_scene, const Spectrum &weight, Spectrum &d_weight,
                    Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalLight,
                          enzyme_dup,   &its_p,          &d_its_p,
                          enzyme_const, &light_id,
                          enzyme_const, &its_n,
                          enzyme_dup,   &its_prev_p,     &d_its_prev_p,
                          enzyme_dup,   &scene,          &d_scene,
                          enzyme_dup,   &weight,         &d_weight,
                          enzyme_dup,   &primal_contrib, &d_contrib);
#endif
    }

    /*
     * Enzyme reverse-mode derivative of evalLightDir(). All arguments are
     * duplicated, so adjoints from d_contrib accumulate into d_its, d_scene and
     * d_weight. Compiles to a no-op unless ENZYME and ENZYME_BOUNDARY_DIRECT
     * are defined.
     */
    void d_evalLightDir(const Intersection &its, Intersection &d_its, 
                    const Scene &scene, Scene &d_scene, const Spectrum &weight, Spectrum &d_weight,
                    Spectrum &d_contrib) 
    {
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        Spectrum primal_contrib; // primal output slot; its value is discarded
        __enzyme_autodiff((void *)evalLightDir,
                          enzyme_dup, &its,            &d_its,
                          enzyme_dup, &scene,          &d_scene,
                          enzyme_dup, &weight,         &d_weight,
                          enzyme_dup, &primal_contrib, &d_contrib);
#endif
    }

    void evalPath(const Scene &scene, const LightPathPSS &path, Float &value)
    {
        // value = 0.0;
        // Spectrum throughput = path.baseValue;
        Spectrum throughputD(1.f), throughputS(1.f);

        if (path.verticesD.size() < 1)
            return;
        // value += edgeContrib;
        Spectrum valueD(0.f), valueS(0.f);
        for (int i = 1; i < path.verticesD.size(); i++)
        {
            throughputD *= path.verticesD[i].value;
        }
        evalCamera(path.verticesD[path.verticesD.size() - 1], scene, throughputD, valueD);
        for (int i = 1; i < path.verticesS.size(); i++)
        {
            throughputS *= path.verticesS[i].value;
        }
        if (path.verticesS.size() > 1){
            const Intersection &curr = path.verticesS[path.verticesS.size() - 1];
            const Shape *ptr_emitter = scene.shape_list[curr.shape_id];
            int light_id = ptr_emitter->light_id;
            evalLight(path.verticesS[path.verticesS.size() - 1].p, light_id, curr.geoFrame.n, path.verticesS[path.verticesS.size() - 2].p, scene, throughputS, valueS);
        } else {
            evalLightDir(path.verticesS[0], scene, throughputS, valueS);
        }
        // PSDR_INFO("valueD: {}, {}, {}", valueD[0], valueD[1], valueD[2]);
        // PSDR_INFO("baseValue: {}", path.baseValue);
        // PSDR_INFO("valueS: {}, {}, {}", valueS[0], valueS[1], valueS[2]);
        // PSDR_INFO("path.verticesS.size(): {}", path.verticesS.size());
        // PSDR_INFO("path.verticesD.size(): {}", path.verticesD.size());
        value = (valueD * valueS * path.baseValue).sum();
    }

    void d_evalPath(const Scene &scene, Scene &d_scene, LightPathPSSAD &pathAD) {
        // Float d_value = 1.0;
        // Float value = 0.0;
        // __enzyme_autodiff((void *)evalPath,
        //                     enzyme_dup, &scene, &d_scene,
        //                     enzyme_dup, &pathAD.val, &pathAD.der,
        //                     enzyme_dup, &value, &d_value);

        Float value = 0.0;
        // Spectrum throughput = pathAD.val.baseValue;

        if (pathAD.val.verticesD.size() < 1)
            return;
        // value += edgeContrib;
        Spectrum throughputD(1.f), throughputS(1.f);
        Spectrum valueD(0.f), valueS(0.f);

        // value += edgeContrib;
        for (int i = 1; i < pathAD.val.verticesD.size(); i++)
        {
            throughputD *= pathAD.val.verticesD[i].value;
        }
        evalCamera(pathAD.val.verticesD[pathAD.val.verticesD.size() - 1], scene, throughputD, valueD);
        for (int i = 1; i < pathAD.val.verticesS.size(); i++)
        {
            throughputS *= pathAD.val.verticesS[i].value;
        }
        if (pathAD.val.verticesS.size() > 1){
            const Intersection &curr = pathAD.val.verticesS[pathAD.val.verticesS.size() - 1];
            const Shape *ptr_emitter = scene.shape_list[curr.shape_id];
            int light_id = ptr_emitter->light_id;
            evalLight(pathAD.val.verticesS[pathAD.val.verticesS.size() - 1].p, light_id, curr.geoFrame.n, pathAD.val.verticesS[pathAD.val.verticesS.size() - 2].p, scene, throughputS, valueS);
        } else {
            evalLightDir(pathAD.val.verticesS[0], scene, throughputS, valueS);
        }

        // Float d_value = 1.0;
        Spectrum d_throughputD(0.f), d_throughputS(0.f);
        Spectrum d_valueS = valueD * pathAD.val.baseValue;
        Spectrum d_valueD = valueS * pathAD.val.baseValue;
        pathAD.der.baseValue += (valueD * valueS).sum();
        
        d_evalCamera(pathAD.val.verticesD[pathAD.val.verticesD.size() - 1], pathAD.der.verticesD[pathAD.val.verticesD.size() - 1], 
                        scene, d_scene, throughputD, d_throughputD, d_valueD);
        {
            Intersection fake_its = pathAD.val.verticesD[pathAD.val.verticesD.size() - 1];
            Intersection d_fake_its;
            d_fake_its.setZero();
            // d_fake_itsD.p = Vector(1.0, 1.0, 1.0);
            Vector d_fake_p_fd = Vector(0.0, 0.0, 0.0);
            Spectrum fake_valueD = Spectrum(0.0, 0.0, 0.0);
            Spectrum d_fake_valueD = Spectrum(1.0, 1.0, 1.0);
            for (int i = 0; i < 3; i++) {
                Intersection its_i = fake_its;
                its_i.wi[i] += 1e-4;
                evalCamera(its_i, scene, throughputD, fake_valueD);
                d_fake_p_fd[i] = (fake_valueD - valueD).sum() / 1e-4;
            }
            Spectrum d_fake_throughput;
            d_evalCamera(fake_its, d_fake_its, 
                        scene, d_scene, throughputD, d_fake_throughput, d_fake_valueD);
            // PSDR_INFO("d_fake_p: {}, {}, {}", d_fake_its.wi[0], d_fake_its.wi[1], d_fake_its.wi[2]);
            // PSDR_INFO("d_fake_p_fd: {}, {}, {}", d_fake_p_fd[0], d_fake_p_fd[1], d_fake_p_fd[2]);
        }

        for (int i = pathAD.val.verticesD.size() - 1; i >= 1; i--) {
            throughputD /= pathAD.val.verticesD[i].value;
            pathAD.der.verticesD[i].value += d_throughputD * throughputD;
            d_throughputD *= pathAD.val.verticesD[i].value;
            // PSDR_INFO("pathAD.der.verticesD[{}].value: {}, {}, {}", i, pathAD.der.verticesD[i].value[0], pathAD.der.verticesD[i].value[1], pathAD.der.verticesD[i].value[2]);
        }

        if (pathAD.val.verticesS.size() > 1){
            const Intersection &curr = pathAD.val.verticesS[pathAD.val.verticesS.size() - 1];
            const Shape *ptr_emitter = scene.shape_list[curr.shape_id];
            int light_id = ptr_emitter->light_id;
            d_evalLight(pathAD.val.verticesS[pathAD.val.verticesS.size() - 1].p, pathAD.der.verticesS[pathAD.val.verticesS.size() - 1].p, 
                        light_id, curr.geoFrame.n, 
                        pathAD.val.verticesS[pathAD.val.verticesS.size() - 2].p, pathAD.der.verticesS[pathAD.val.verticesS.size() - 2].p, 
                        scene, d_scene, throughputS, d_throughputS, d_valueS);
        } else {
            d_evalLightDir(pathAD.val.verticesS[0], pathAD.der.verticesS[0], 
                        scene, d_scene, throughputS, d_throughputS, d_valueS);
        }
        for (int i = pathAD.val.verticesS.size() - 1; i >= 1; i--) {
            throughputS /= pathAD.val.verticesS[i].value;
            pathAD.der.verticesS[i].value += d_throughputS * throughputS;
            d_throughputS *= pathAD.val.verticesS[i].value;
            // PSDR_INFO("pathAD.der.verticesS[{}].value: {}, {}, {}", i, pathAD.der.verticesS[i].value[0], pathAD.der.verticesS[i].value[1], pathAD.der.verticesS[i].value[2]);
        }
    }

    /*
     * Rebuilds every vertex of `path` from its primary-sample-space state.
     *   - Vertex 0 of both sub-paths (and path.baseValue) comes from the edge
     *     sample driven by path.pss_stateE.
     *   - Camera sub-path vertices i >= 1 use path.pss_stateD, BSDF-sampled in
     *     EImportanceWithCorrection mode.
     *   - Light sub-path vertices i >= 1 use path.pss_stateS, BSDF-sampled in
     *     ERadiance mode. This matches the *Rad derivative variants that
     *     d_getPath() applies to this sub-path.
     * Expects verticesD / verticesS to be pre-sized with their `indices` filled.
     */
    void getPath(const Scene &scene, LightPathPSS &path) {
        // A copy of the edge state is passed, as before.
        Vector rnd_0 = path.pss_stateE;
        getEdgeIntersectionDir(rnd_0, scene, path.ePSRec,
                               path.verticesD[0].indices, path.verticesS[0].indices,
                               path.baseValue, path.verticesD[0], path.verticesS[0]);
        // Alternative (non-DIFF_DIM0) parameterization, currently disabled:
        // edgePtDirToIntersection(path.edge_point, path.edge_dir, scene, path.ePSRec,
        //                     path.verticesD[0].indices, path.verticesS[0].indices, path.baseValue, path.verticesD[0], path.verticesS[0]);

        // Camera sub-path (D): importance transport.
        for (int i = 1; i < static_cast<int>(path.verticesD.size()); i++) {
            Vector rnd = MALA::pss_get(path.pss_stateD, i - 1);
            evalIntersection(scene, *path.verticesD[i - 1].ptr_bsdf, rnd, path.verticesD[i].indices,
                             EBSDFMode::EImportanceWithCorrection,
                             path.verticesD[i - 1], path.verticesD[i]);
        }

        // Light sub-path (S): radiance transport. This must agree with the mode
        // differentiated in d_getPath(), otherwise forward and reverse disagree
        // for non-symmetric BSDFs (e.g. rough dielectrics).
        for (int i = 1; i < static_cast<int>(path.verticesS.size()); i++) {
            Vector rnd = MALA::pss_get(path.pss_stateS, i - 1);
            evalIntersection(scene, *path.verticesS[i - 1].ptr_bsdf, rnd, path.verticesS[i].indices,
                             EBSDFMode::ERadiance,
                             path.verticesS[i - 1], path.verticesS[i]);
        }
    }

    void d_getPath(const Scene &scene, Scene &d_scene, LightPathPSSAD &pathAD) {
        // PSDR_INFO("path size: {}", pathAD.val.vertices.size());
        for (int i = pathAD.val.verticesD.size() - 1; i >= 1; i--) {
            // Vector rnd = pathAD.val.pss_state.get(i);
            Vector rnd = MALA::pss_get(pathAD.val.pss_stateD, i - 1);
            Vector d_rnd(0.0, 0.0, 0.0);
            // PSDR_INFO("d_value {}: {}, {}, {}", i, pathAD.der.verticesD[i].value[0], pathAD.der.verticesD[i].value[1], pathAD.der.verticesD[i].value[2]);
            // PSDR_INFO("d_p {}: {}, {}, {}", i, pathAD.der.verticesD[i].p[0], pathAD.der.verticesD[i].p[1], pathAD.der.verticesD[i].p[2]);
            // PSDR_INFO("d_wo {}: {}, {}, {}", i, pathAD.der.verticesD[i].wi[0], pathAD.der.verticesD[i].wi[1], pathAD.der.verticesD[i].wi[2]);
            // PSDR_INFO("d_wi {}: {}, {}, {}", i, pathAD.der.verticesD[i - 1].wi[0], pathAD.der.verticesD[i - 1].wi[1], pathAD.der.verticesD[i - 1].wi[2]);
            char bsdf_name[100];
            bsdf_name[0] = 0;
            pathAD.val.verticesD[i-1].ptr_bsdf->className(bsdf_name);
            // d_evalIntersection(scene, d_scene, *pathAD.val.verticesD[i-1].ptr_bsdf, 
            //                 rnd, d_rnd, pathAD.val.verticesD[i].indices, EBSDFMode::EImportanceWithCorrection,
            //                 pathAD.val.verticesD[i - 1], pathAD.der.verticesD[i - 1], pathAD.der.verticesD[i]);
            if (strcmp(bsdf_name, "RoughConductorBSDF") == 0) {
                const RoughConductorBSDF* rc = dynamic_cast<const RoughConductorBSDF*>(pathAD.val.verticesD[i-1].ptr_bsdf);
                d_evalIntersectionRCIC(scene, d_scene, *rc, rnd, d_rnd, pathAD.val.verticesD[i].indices, 
                                    pathAD.val.verticesD[i - 1], pathAD.der.verticesD[i - 1], pathAD.der.verticesD[i]);
            } else if (strcmp(bsdf_name, "RoughDielectricBSDF") == 0) {
                const RoughDielectricBSDF* rd = dynamic_cast<const RoughDielectricBSDF*>(pathAD.val.verticesD[i-1].ptr_bsdf);
                d_evalIntersectionRDIC(scene, d_scene, *rd, rnd, d_rnd, pathAD.val.verticesD[i].indices, 
                                    pathAD.val.verticesD[i - 1], pathAD.der.verticesD[i - 1], pathAD.der.verticesD[i]);
            } else if (strcmp(bsdf_name, "DiffuseBSDF") == 0) {
                const DiffuseBSDF* df = dynamic_cast<const DiffuseBSDF*>(pathAD.val.verticesD[i-1].ptr_bsdf);
                d_evalIntersectionDiffuseIC(scene, d_scene, *df, rnd, d_rnd, pathAD.val.verticesD[i].indices, 
                                    pathAD.val.verticesD[i - 1], pathAD.der.verticesD[i - 1], pathAD.der.verticesD[i]);
            } else {
                // PSDR_INFO("bsdf name: {}", bsdf_name);
                assert(false);
            }
            // pathAD.der.pss_state.set(i, d_rnd);
            MALA::pss_set(pathAD.der.pss_stateD, i - 1, d_rnd);
            // PSDR_INFO("d_rnd {}: {}, {}, {}", i, d_rnd[0], d_rnd[1], d_rnd[2]);
            // PSDR_INFO("d_wi {}: {}, {}, {}", i, pathAD.der.verticesD[i - 1].wi[0], pathAD.der.verticesD[i - 1].wi[1], pathAD.der.verticesD[i - 1].wi[2]);
        }

        for (int i = pathAD.val.verticesS.size() - 1; i >= 1; i--) {
            // Vector rnd = pathAD.val.pss_state.get(i);
            Vector rnd = MALA::pss_get(pathAD.val.pss_stateS, i - 1);
            Vector d_rnd(0.0, 0.0, 0.0);
            // PSDR_INFO("d_value {}: {}, {}, {}", i, pathAD.der.verticesS[i].value[0], pathAD.der.verticesS[i].value[1], pathAD.der.verticesS[i].value[2]);
            // PSDR_INFO("d_p {}: {}, {}, {}", i, pathAD.der.verticesS[i].p[0], pathAD.der.verticesS[i].p[1], pathAD.der.verticesS[i].p[2]);
            // PSDR_INFO("d_wo {}: {}, {}, {}", i, pathAD.der.verticesS[i].wi[0], pathAD.der.verticesS[i].wi[1], pathAD.der.verticesS[i].wi[2]);
            char bsdf_name[100];
            bsdf_name[0] = 0;
            pathAD.val.verticesS[i-1].ptr_bsdf->className(bsdf_name);
            // d_evalIntersection(scene, d_scene, *pathAD.val.verticesS[i-1].ptr_bsdf, 
            //                 rnd, d_rnd, pathAD.val.verticesS[i].indices, EBSDFMode::ERadiance,
            //                 pathAD.val.verticesS[i - 1], pathAD.der.verticesS[i - 1], pathAD.der.verticesS[i]);
            // PSDR_INFO("bsdf name: {}", bsdf_name);
            if (strcmp(bsdf_name, "RoughConductorBSDF") == 0) {
                Vector wo = (pathAD.der.verticesS[i].p - pathAD.der.verticesS[i - 1].p).normalized();
                Vector d_wo(0.0, 0.0, 0.0);
                const RoughConductorBSDF* rc = dynamic_cast<const RoughConductorBSDF*>(pathAD.val.verticesS[i-1].ptr_bsdf);
                d_evalIntersectionRCRad(scene, d_scene, *rc, rnd, d_rnd, pathAD.val.verticesS[i].indices, 
                                    pathAD.val.verticesS[i - 1], pathAD.der.verticesS[i - 1], pathAD.der.verticesS[i]);
            } else if (strcmp(bsdf_name, "RoughDielectricBSDF") == 0) {
                Vector wo = (pathAD.der.verticesS[i].p - pathAD.der.verticesS[i - 1].p).normalized();
                Vector d_wo(0.0, 0.0, 0.0);
                const RoughDielectricBSDF* rd = dynamic_cast<const RoughDielectricBSDF*>(pathAD.val.verticesS[i-1].ptr_bsdf);
                d_evalIntersectionRDRad(scene, d_scene, *rd, rnd, d_rnd, pathAD.val.verticesS[i].indices, 
                                    pathAD.val.verticesS[i - 1], pathAD.der.verticesS[i - 1], pathAD.der.verticesS[i]);
            } else if (strcmp(bsdf_name, "DiffuseBSDF") == 0) {
                const DiffuseBSDF* df = dynamic_cast<const DiffuseBSDF*>(pathAD.val.verticesS[i-1].ptr_bsdf);
                d_evalIntersectionDiffuseRad(scene, d_scene, *df, rnd, d_rnd, pathAD.val.verticesS[i].indices, 
                                    pathAD.val.verticesS[i - 1], pathAD.der.verticesS[i - 1], pathAD.der.verticesS[i]);
            } else {
                // PSDR_INFO("bsdf name: {}", bsdf_name);
                assert(false);
            }
            // pathAD.der.pss_state.set(i, d_rnd);
            MALA::pss_set(pathAD.der.pss_stateS, i - 1, d_rnd);
            // PSDR_INFO("d_rnd {}: {}, {}, {}", i, d_rnd[0], d_rnd[1], d_rnd[2]);
            // PSDR_INFO("d_wi {}: {}, {}, {}", i, pathAD.der.verticesS[i - 1].wi[0], pathAD.der.verticesS[i - 1].wi[1], pathAD.der.verticesS[i - 1].wi[2]);
        }

        // Vector rnd_0 = pathAD.val.pss_state.get(0);
        Vector rnd_0 = pathAD.val.pss_stateE;
        Vector d_rnd_0(0.0, 0.0, 0.0);
        // PSDR_INFO("pathAD.der.baseValue: {}", pathAD.der.baseValue);
        // PSDR_INFO("pathAD.der.verticesD[0].p: {}, {}, {}", pathAD.der.verticesD[0].p[0], pathAD.der.verticesD[0].p[1], pathAD.der.verticesD[0].p[2]);
        // PSDR_INFO("pathAD.der.verticesS[0].p: {}, {}, {}", pathAD.der.verticesS[0].p[0], pathAD.der.verticesS[0].p[1], pathAD.der.verticesS[0].p[2]);
// #ifdef DIFF_DIM0
        //
        d_getEdgeIntersectionDir(rnd_0, d_rnd_0, scene, d_scene, pathAD.val.ePSRec, 
                                pathAD.val.verticesD[0].indices, pathAD.val.verticesS[0].indices, 
                                pathAD.der.baseValue, 
                                pathAD.der.verticesD[0], pathAD.der.verticesS[0]);
        // d_rnd_0[1] /= 4.0;
        // d_rnd_0[2] /= 4.0;

        // {
        //     Intersection fake_itsD = pathAD.val.verticesD[0];
        //     Intersection fake_itsS = pathAD.val.verticesS[0];
        //     Intersection d_fake_itsD, d_fake_itsS;
        //     d_fake_itsD.setZero();
        //     d_fake_itsS.setZero();
        //     // d_fake_itsD.p = Vector(1.0, 1.0, 1.0);
        //     Float d_fake_baseValue = 1.0;
        //     Vector d_fake_rnd_0_fd = Vector(0.0, 0.0, 0.0);
        //     Vector d_fake_rnd_0 = Vector(0.0, 0.0, 0.0);
        //     Float fake_baseValue = 1.0;
        //     for (int i = 0; i < 3; i++) {
        //         Vector rnd_i = rnd_0;
        //         rnd_i[i] += 1e-4;
        //         getEdgeIntersectionDir(rnd_i, scene, pathAD.val.ePSRec, 
        //                             pathAD.val.verticesD[0].indices, pathAD.val.verticesS[0].indices, fake_baseValue, fake_itsD, fake_itsS);
        //         d_fake_rnd_0_fd[i] = (fake_baseValue - pathAD.val.baseValue) / 1e-4;
        //     }
        //     d_getEdgeIntersectionDir(rnd_0, d_fake_rnd_0, scene, d_scene, pathAD.val.ePSRec, 
        //                         pathAD.val.verticesD[0].indices, pathAD.val.verticesS[0].indices, 
        //                         d_fake_baseValue, 
        //                         d_fake_itsD, d_fake_itsS);
        //     PSDR_INFO("d_fake_rnd_0: {}, {}, {}", d_fake_rnd_0[0], d_fake_rnd_0[1], d_fake_rnd_0[2]);
        //     PSDR_INFO("d_fake_rnd_0_fd: {}, {}, {}", d_fake_rnd_0_fd[0], d_fake_rnd_0_fd[1], d_fake_rnd_0_fd[2]);
        //     PSDR_INFO("ePSRec.pdf: {}", pathAD.val.ePSRec.pdf);
        //     PSDR_INFO("ePSRec.scale: {}", pathAD.val.ePSRec.scale);
        // }
// #else
        // d_edgePtDirToIntersection(pathAD.val.edge_point, pathAD.val.edge_dir, pathAD.der.edge_point, pathAD.der.edge_dir, 
        //                         scene, d_scene, pathAD.val.ePSRec, 
        //                         pathAD.val.verticesD[0].indices, pathAD.val.verticesS[0].indices, 
        //                         pathAD.der.baseValue, 
        //                         pathAD.der.verticesD[0], pathAD.der.verticesS[0]);
// #endif
        // pathAD.der.pss_state.set(0, d_rnd_0);
        pathAD.der.pss_stateE = d_rnd_0;
    }

    

    /*
     * Traces exactly n_bounces BSDF-sampled bounces starting from _ray, using
     * three primary-sample coordinates per bounce from rnd (rnd.size() must be
     * 3 * n_bounces).
     *
     * Returns throughput * emitted radiance when the final vertex lies on an
     * emitter. Returns zero when:
     *   - the first hit is invalid,
     *   - a sampled BSDF weight is zero,
     *   - a bounce ray escapes, or
     *   - the final vertex does not emit.
     *
     * When `path` is non-null, each bounce:
     *   - zeroes its three entries of path->discrete_dimS,
     *   - records its lobe tag in path->typeS ('d', 'r', 't', 'n' for
     *     bsdf_type 0, 1, 2, -1),
     *   - appends the newly hit vertex to path->verticesS, with the sampled
     *     BSDF weight stored in its `value`.
     */
    Spectrum __Li(const Scene &scene, const Ray &_ray, const MALA::MALAVector rnd, int n_bounces, LightPathPSS *path = nullptr)
    {
        assert(rnd.size() == 3 * n_bounces);

        Ray ray(_ray);
        Intersection its;
        scene.rayIntersect(ray, true, its);
        if (!its.isValid())
            return Spectrum::Zero();

        Spectrum throughput = Spectrum::Ones();
        its.value = Spectrum::Ones();

        // Zero-bounce query: the primary hit itself must be the emitter.
        if (its.isEmitter() && n_bounces == 0)
            return throughput * its.Le(-ray.dir);

        for (int depth = 0; depth < n_bounces && its.isValid(); ++depth) {
            Vector wo;
            Float bsdf_pdf, bsdf_eta;
            auto bsdf_weight = its.sampleBSDF(MALA::pss_get(rnd, depth), wo, bsdf_pdf, bsdf_eta);

            if (path != nullptr) {
                for (int k = 0; k < 3; ++k)
                    path->discrete_dimS[3 * depth + k] = 0.0;
                const auto lobe = its.bsdf_type;
                if (lobe == 0)
                    path->typeS[depth] = 'd';
                else if (lobe == 1)
                    path->typeS[depth] = 'r';
                else if (lobe == 2)
                    path->typeS[depth] = 't';
                else if (lobe == -1)
                    path->typeS[depth] = 'n';
            }

            if (bsdf_weight.isZero(Epsilon))
                return Spectrum::Zero();

            // wo was sampled in the local shading frame.
            ray = Ray(its.p, its.toWorld(wo));
            if (!scene.rayIntersect(ray, true, its))
                return Spectrum::Zero();

            its.value = bsdf_weight;
            if (path != nullptr)
                path->verticesS.push_back(its);
            throughput *= bsdf_weight;
        }

        if (!its.isEmitter())
            return Spectrum::Zero();

        its.value = Spectrum::Ones();
        const Spectrum emitted = its.Le(-ray.dir);
        if (emitted.isZero(Epsilon))
            return Spectrum::Zero();
        return throughput * emitted;
    }
    
    /*
     * Per-query state for boundary radiance estimation.
     * Holds:
     *   - the sampler to draw from (not owned; the record never frees it),
     *   - the bounce budget,
     *   - one zero-initialized Spectrum slot per depth 0..max_bounces.
     */
    struct BoundaryRadianceQueryRecord
    {
        RndSampler *sampler;          // borrowed, not freed here
        int max_bounces;
        std::vector<Spectrum> values; // max_bounces + 1 entries

        BoundaryRadianceQueryRecord(RndSampler *sampler, int max_bounces)
            : sampler(sampler),
              max_bounces(max_bounces),
              values(max_bounces + 1, Spectrum::Zero())
        {
        }
    };
    
    Spectrum eval(const Scene &scene,
                    const ArrayXd &rnd, int max_bounces, int cam_bounce,
                    const DiscreteDistribution &edge_dist,
                    const std::vector<Vector2i> &edge_indices,
                    LightPathPSS *path, bool echo)
    {   
        // if (echo)
        //     d_sampleDirectBoundary(scene, rnd, max_bounces, cam_bounce, edge_dist, edge_indices, echo);

        /* cam_bounces + light_bounces + 1 (edge ray) = max_bounces */
        assert(max_bounces > cam_bounce);
        assert(max_bounces >= 0);
        assert(cam_bounce >= 0);
        assert(rnd.size() == 3 * max_bounces);
        Vector pss_stateE = MALA::pss_get(rnd, 0);
        MALA::MALAVector pss_stateD, pss_stateS;
        MALA::MALAVector discrete_dimD, discrete_dimS;
        if (max_bounces - cam_bounce - 1 > 0){
            pss_stateS.resize(3 * (max_bounces - cam_bounce - 1));
            discrete_dimS.resize(3 * (max_bounces - cam_bounce - 1));
            discrete_dimS.setZero();
            pss_stateS = rnd.segment(3, 3 * (max_bounces - cam_bounce - 1));
        }
        if (cam_bounce > 0){
            pss_stateD.resize(3 * cam_bounce);
            discrete_dimD.resize(3 * cam_bounce);
            discrete_dimD.setZero();
            pss_stateD = rnd.segment(3 * (max_bounces - cam_bounce), 3 * cam_bounce);
        }
        if (path != nullptr){
            path->verticesD.clear();
            path->verticesS.clear();
            path->pss_stateE = pss_stateE;
            if (max_bounces - cam_bounce - 1 > 0){
                path->pss_stateS.resize(pss_stateS.size());
                path->pss_stateS = pss_stateS;
                path->discrete_dimS.resize(discrete_dimS.size());
                path->discrete_dimS = discrete_dimS;
            }
            if (cam_bounce > 0){
                path->pss_stateD.resize(pss_stateD.size());
                path->pss_stateD = pss_stateD;
                path->discrete_dimD.resize(discrete_dimD.size());
                path->discrete_dimD = discrete_dimD;
            }
        }
        Spectrum contrib(0.0);
        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, pss_stateE, edge_dist, edge_indices, eRec, (path == nullptr) ? nullptr : &(path->ePSRec));
        if (echo) {
            PSDR_INFO("eRec.ref: {}, {}, {}", eRec.ref[0], eRec.ref[1], eRec.ref[2]);
            PSDR_INFO("eRec.dir: {}, {}, {}", eRec.dir[0], eRec.dir[1], eRec.dir[2]);
        }
        if (eRec.shape_id == -1)
        {
            // PSDR_WARN(eRec.shape_id == -1);
            return contrib;
        }
        if (path != nullptr){
            path->bound.min = path->ePSRec.min;
            path->bound.max = path->ePSRec.max;
            path->bound.edge_idx = path->ePSRec.edge_id;
            path->bound.shape_idx = path->ePSRec.shape_id;
            assert(path->ePSRec.shape_id >= 0 && path->ePSRec.shape_id < scene.shape_list.size());
            assert(path->bound.shape_idx >= 0 && path->bound.shape_idx < scene.shape_list.size());
            assert(path->bound.edge_idx >= 0 && path->bound.edge_idx < scene.shape_list[path->bound.shape_idx]->edges.size());
// #ifndef DIFF_DIM0
            path->edge_point = eRec.ref;
            path->edge_dir = eRec.dir;
// #endif
            if (echo) {
                PSDR_INFO("path->bound.edge_idx: {}", path->bound.edge_idx);
                PSDR_INFO("path->bound.shape_idx: {}", path->bound.shape_idx);
                PSDR_INFO("rnd[0]: {}", pss_stateE[0]);
                PSDR_INFO("ePSRec.shape_id: {}", eRec.shape_id);
                PSDR_INFO("ePSRec.edge_id: {}", eRec.edge_id);
            }
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        Ray edgeRay(eRec.ref, eRec.dir);
        if (path != nullptr) {
            path->bound.dir = edgeRay.dir;
        }
        Intersection itsS, its;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, its))
            return contrib;

        /* Sample point on emitters */
        // DirectSamplingRecord dRec(eRec.ref);
        // Spectrum value = itsS.Le(-edgeRay.dir);
        
        // if (value.sum() > 0){
        //     PSDR_INFO("n, i: {}, {}", max_bounces, cam_bounce);
        //     PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);
        // }

        if (path != nullptr) {
            path->typeD[cam_bounce] = '\0';
            path->typeS[max_bounces - cam_bounce - 1] = '\0';
        }
        const Vector xB = eRec.ref,
                     &xS = itsS.p;
        // Ray ray(xB, (xB - xS).normalized());
        // Intersection its;
        // if (!scene.rayIntersect(ray, true, its))
        //     return contrib;

        // sanity check
        // make sure the ray is tangent to the surface
        // if (edge.f0 >= 0 && edge.f1 >= 0)
        // {
        //     Vector n0 = shape->getGeoNormal(edge.f0),
        //            n1 = shape->getGeoNormal(edge.f1);
        //     Float dotn0 = ray.dir.dot(n0),
        //           dotn1 = ray.dir.dot(n1);
        //     if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
        //         return contrib;
        // }
        // NOTE prevent intersection with a backface
        const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
        if (itsS.indices == ind0 || itsS.indices == ind1 ||
            its.indices == ind0 || its.indices == ind1)
            return contrib;
        // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
        const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                    gn2d1 = its.geoFrame.n.dot(edgeRay.dir), sn2d1 = its.shFrame.n.dot(edgeRay.dir);
        bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
             valid2 = (its.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!its.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
        if (itsS.isEmitter())
            valid1 = true;
        if (!valid1 || !valid2)
            return contrib;

        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = edgeRay.dir;
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];

        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector &xD = its.p;
        Vector n = (v0 - v1).cross(-edgeRay.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, its.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, its.geoFrame.n, xS, itsS.geoFrame.n);

        if (std::abs(baseValue) < Epsilon)
            return contrib;
        its.value = Spectrum(1.0);
        if (path != nullptr){
            path->baseValue = baseValue;
            path->verticesD.clear();
            path->verticesS.clear();
            path->verticesD.push_back(its);
            path->verticesS.push_back(itsS);
        }
        edgeRay.tmax = std::numeric_limits<Float>::infinity();
        Spectrum valueS = __Li(scene, edgeRay, pss_stateS, max_bounces - cam_bounce - 1, path);
        if (echo){
            PSDR_INFO("edgeRay.dir: {}, {}, {}", edgeRay.dir[0], edgeRay.dir[1], edgeRay.dir[2]);
            PSDR_INFO("baseValue: {}", baseValue);
            PSDR_INFO("valueS: {}, {}, {}", valueS[0], valueS[1], valueS[2]);
            if (path != nullptr){
                for (int i = 0; i < path->verticesS.size(); i++){
                    PSDR_INFO("    lightPath pos {}: {}, {}, {}", i, path->verticesS[i].p[0], path->verticesS[i].p[1], path->verticesS[i].p[2]);
                }
                // Intersection its = path->verticesS[path->verticesS.size() - 1];
                // PSDR_INFO("emitter pos: {}, {}, {}", its.p[0], its.p[1], its.p[2]);
            }
        }
        if (valueS.isZero(Epsilon))
            return contrib;
        // assert(baseValue > -Epsilon);

        /* Sample detector path */
        Spectrum throughput(1.0f);
        Spectrum valueD(0.f);
        Ray ray_sensor;

        // PSDR_INFO("here3");
        for (int i = 0; i < cam_bounce; i++)
        {

            // contrib += contrib_i;
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Vector pss_cur = MALA::pss_get(pss_stateD, i);
            if (echo){
                // PSDR_INFO("pss_cur {}: {}, {}, {}", i, pss_cur[0], pss_cur[1], pss_cur[2]);
            }
            Spectrum bsdf_weight = its.sampleBSDF(pss_cur, wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            if (path != nullptr){
                MALA::pss_set(path->discrete_dimD, i, Vector(0.0, 0.0, 0.0));
                // path->discrete_dimD[i * 3] = 0.0;
                // path->discrete_dimD[i * 3 + 1] = 0.0;
                // path->discrete_dimD[i * 3 + 2] = 1.0;
                if (its.bsdf_type == 0) {
                    path->typeD[i] = 'd';
                } else if (its.bsdf_type == 1) {
                    path->typeD[i] = 'r';
                } else if (its.bsdf_type == 2) {
                    path->typeD[i] = 't';
                } else if (its.bsdf_type == -1) {
                    path->typeD[i] = 'n';
                
                }
            }
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0){
                // PSDR_INFO("Invalid intersection, bounce: {}", i);
                return contrib;
            }
            throughput *= bsdf_weight;
            ray_sensor = Ray(its.p, wo);
            bool hit = scene.rayIntersect(ray_sensor, true, its);
            its.value = bsdf_weight;
            if (!hit || !its.isValid()){
                // PSDR_INFO("Invalid intersection, bounce: {}", i);
                return contrib;
            }
            if (path != nullptr){
                path->verticesD.push_back(its);
                if (echo){
                    PSDR_INFO("cam iter: {}", i);
                    PSDR_INFO("p[{}]: {}, {}, {}", i+1, its.p[0], its.p[1], its.p[2]);
                }
            }
        }
        // if (path != nullptr){
        //     path->verticesD.push_back(its);
        // }
        CameraDirectSamplingRecord cRec;
        Vector cam_dir = (scene.camera.cpos - its.p).normalized();
        if (!scene.camera.sampleDirect(its.p, cRec)){
            if (echo){
                PSDR_INFO("its.p: {}, {}, {}", its.p[0], its.p[1], its.p[2]);
            }
            return contrib;
        }
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true)){
            if (echo){
                PSDR_INFO("its.p: {}, {}, {}", its.p[0], its.p[1], its.p[2]);
            }
            return contrib;
        }
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, 0.5);
        its.pixel_idx = pixel_idx;
        if (sensor_val < Epsilon){
            if (echo){
                PSDR_INFO("sensor_val: {}", sensor_val);
            }
            return contrib;
        }
        if (pixel_idx >= 0 && pixel_idx < scene.camera.getNumPixels()){
            // PSDR_INFO("here1");
            // PSDR_INFO("its.ptr_shape: {}", its.ptr_shape->bsdf_id);
            if (echo){
                evalCamera_echo(its, scene, throughput, valueD);
                PSDR_INFO("pixel_idx: {}", pixel_idx);
                PSDR_INFO("throughput: {}, {}, {}", throughput[0], throughput[1], throughput[2]);
                PSDR_INFO("valueD: {}, {}, {}", valueD[0], valueD[1], valueD[2]);
            } else {
                evalCamera(its, scene, throughput, valueD);
            }
            // PSDR_INFO("here2");
        }
        contrib = valueD * valueS * baseValue / eRec.pdf;
        return contrib;
    }

    /**
     * Differentiate a single primary-sample-space path.
     *
     * Resets the derivative buffer `pathAD.der` to zero (sized to match
     * `pathAD.val`). It then regenerates and evaluates the primal path
     * (`getPath`, then `evalPath`). Finally it runs the derivative passes in the
     * reverse order of those forward passes (`d_evalPath`, then `d_getPath`).
     *
     * @param scene    Scene the path is traced in (primal, read-only).
     * @param d_scene  Scene that presumably receives the derivative
     *                 contributions written by d_evalPath / d_getPath
     *                 (TODO confirm in those functions).
     * @param pathAD   Primal (`val`) / derivative (`der`) path pair. `der` is
     *                 cleared on entry, so no previous value is accumulated.
     */
    void d_eval(const Scene &scene, Scene &d_scene,
                  LightPathPSSAD &pathAD
                  ) {
        // Reset the derivative so no existing value is accumulated.
        pathAD.der.clear();
        pathAD.der.resize(pathAD.val);
        pathAD.der.setZero();

        getPath(scene, pathAD.val);

        // The evaluated scalar is not used here.
        // NOTE(review): the call is kept in case evalPath updates state in
        // pathAD.val that d_evalPath relies on. Confirm before removing it.
        Float value = 0.0;
        evalPath(scene, pathAD.val, value);

        d_evalPath(scene, d_scene, pathAD);
        d_getPath(scene, d_scene, pathAD);
    }
NAMESPACE_END(algorithm1_MALA_direct)