#include "algorithm1.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <bsdf/roughconductor.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughdielectric.h>
#include <mala_utils.h>

NAMESPACE_BEGIN(algorithm1_MALA_direct)    


    void sampleEdgeRay(const Scene &scene, const Array3 rnd,
                       const DiscreteDistribution &edge_dist,
                       const std::vector<Vector2i> &edge_indices,
                       EdgeRaySamplingRecord &eRec, 
                       EdgePrimarySampleRecord *ePSRec = nullptr)
    {
        /* Sample a point on the boundary */
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec, ePSRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample edge ray */
        // if (edge.f1 < 0) // Case 1: boundary edge
        // {
            eRec.dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            eRec.pdf /= 4. * M_PI;
        // }
        if (edge.f1 >= 0) // Case 2: non-boundary edge
        {
            Vector n0 = shape->getGeoNormal(edge.f0);
            Vector n1 = shape->getGeoNormal(edge.f1);
        //     Float pdf1;
        //     eRec.dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
        //     eRec.pdf *= pdf1;

            Float dotn0 = eRec.dir.dot(n0), dotn1 = eRec.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5f)
            {
                // std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
                eRec.shape_id = -1;
            }
        }
        if (ePSRec != nullptr) {
            ePSRec->pdf = eRec.pdf;
        }
    }

    void getEdgeIntersectionDir(const Vector &rnd, const Scene &scene, 
                    const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
                    const Array2i &indices,
                    Spectrum &contrib, Intersection &itsD) {
        /* Sample point on emitters */
        Vector light_norm = dPSRec.n;
        Vector _rnd(rnd);
        _rnd[0] = (_rnd[0] - ePSRec.offset) / ePSRec.scale;

        Vector edge_point = ePSRec.v0 + _rnd[0] * (ePSRec.v1 - ePSRec.v0);

        const Shape *shape = scene.shape_list[ePSRec.shape_id];
        const Edge &edge = shape->edges[ePSRec.edge_id];
        Vector dir;
        // Vector emitter_point;
        // if (edge.f1 < 0) // Case 1: boundary edge
        // {
            dir = squareToUniformSphere(Vector2{_rnd[1], _rnd[2]});
        // }
        // else // Case 2: non-boundary edge
        // {
        //     Vector n0 = shape->getGeoNormal(edge.f0);
        //     Vector n1 = shape->getGeoNormal(edge.f1);
        //     Float pdf1;
        //     dir = squareToEdgeRayDirection(Vector2(_rnd[1], _rnd[2]), n0, n1, pdf1);
        // }
        // if (ePSRec.ray_flipped) {
        //     dir = -dir;
        // }

        // Vector dir = edge_point - emitter_point;
        // Float inv_dist_emitter = 1.0 / dir.norm();
        // dir *= inv_dist_emitter;
        // value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, dir);
        
        Ray edgeRay(edge_point, dir);

        // scene.rayIntersect(edgeRay, true, itsS);
        // scene.rayIntersect(edgeRay.flipped(), true, its);
        // Intersection its;
        // scene.rayIntersect(ray, true, its);
        const Shape* shape_cam = scene.getShape(indices[0]);
        const Emitter* shape_emitter = scene.emitter_list[dPSRec.emitter_id];
        // scene.rayIntersect(ray, true, its);
        // shape_cam->rayIntersect(its_Cam.indices[1], ray, holder);
        Vector3i idx_cam = shape_cam->getIndices(indices[1]);
        Vector v0_cam = detach(shape_cam->getVertex(idx_cam[0]));
        Vector v1_cam = detach(shape_cam->getVertex(idx_cam[1]));
        Vector v2_cam = detach(shape_cam->getVertex(idx_cam[2]));

        Array uvt = rayIntersectTriangle(v0_cam, v1_cam, v2_cam, edgeRay.flipped());
        Vector detector_point = edge_point + uvt[2] * (-dir);
        Vector geo_n = detach(shape_cam->getFaceNormal(indices[1]));
        Vector sh_n = detach(shape_cam->getShadingNormal(indices[1], Vector2(uvt[0], uvt[1])));
        itsD.p = detector_point;
        itsD.geoFrame = Frame(geo_n);
        itsD.shFrame = Frame(sh_n);
        itsD.wi = itsD.toLocal(dir);
        itsD.uv = Vector2(uvt[0], uvt[1]);

        Array uvtS = rayIntersectTriangle(dPSRec.v0, dPSRec.v1, dPSRec.v2, edgeRay);
        Vector emitter_point = edge_point + uvtS[2] * dir;

        Spectrum value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, -dir);
        Float J = dlD_dlB(emitter_point,
                          edge_point, (ePSRec.v0 - ePSRec.v1).normalized(),
                          detector_point, geo_n) *
                  dA_dw(edge_point, emitter_point, light_norm);

        // Float sensorVal = scene.camera.eval(pixel_idx[0], pixel_idx[1], its_Cam.p);
        
        // Spectrum bsdf_val = its_Cam.evalBSDF(-dir);
        Float baseValue = J * geometric(detector_point, geo_n, emitter_point, light_norm);
        contrib = value * baseValue / ePSRec.pdf; // eRec.pdf, bsdf_val
    }

    /**
     * Derivative wrapper around getEdgeIntersectionDir() using Enzyme's
     * __enzyme_autodiff.
     *
     * rnd, scene, contrib and itsD are marked enzyme_dup and paired with the
     * shadows d_rnd, d_scene, d_contrib and d_itsD; ePSRec, dPSRec and indices
     * are marked enzyme_const.
     * NOTE(review): presumably d_contrib / d_itsD carry the incoming adjoints
     * and d_rnd / d_scene receive the accumulated gradients — confirm against
     * the Enzyme calling convention used by this project.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_getEdgeIntersectionDir(const Vector &rnd, Vector &d_rnd, const Scene &scene, Scene &d_scene,
                    const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
                    const Array2i &indices,
                    Spectrum &d_contrib, Intersection &d_itsD) {
        // Primal output slots required by the call; their values are not used
        // afterwards ([[maybe_unused]] covers the build without Enzyme).
        [[maybe_unused]] Spectrum contrib;
        [[maybe_unused]] Intersection itsD;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of getEdgeIntersectionDir().
        __enzyme_autodiff((void *)getEdgeIntersectionDir,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &ePSRec,
                            enzyme_const, &dPSRec,
                            enzyme_const, &indices,
                            enzyme_dup, &contrib, &d_contrib,
                            enzyme_dup, &itsD, &d_itsD);
#endif
    }

    void getEdgeIntersection(const Vector &rnd, const Scene &scene, 
                    const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
                    const Array2i &indices,
                    Spectrum &contrib) {

        /* Sample point on emitters */
        // DirectSamplingRecord dRec(eRec.ref);
        Vector light_norm = dPSRec.n;
        Vector _rnd(rnd);
        _rnd[0] = (_rnd[0] - ePSRec.offset) / ePSRec.scale;

        Vector edge_point = ePSRec.v0 + _rnd[0] * (ePSRec.v1 - ePSRec.v0);

        Vector emitter_point;
        if (!dPSRec.continuous){ // sample_triangle -> sample barycentric
            _rnd[1] = (_rnd[1] - dPSRec.offset) / dPSRec.scale;
            Float a = std::sqrt(_rnd[1]);
            Float u = 1.0f - a, v = a * _rnd[2];
            emitter_point = dPSRec.v0 + (dPSRec.v1 - dPSRec.v0) * u + (dPSRec.v2 - dPSRec.v0) * v;
        } else { // sample UV -> find barycentric -> find position
            Vector query_point = Vector(_rnd[1], _rnd[2], 1.0);
            Ray query_ray = Ray(query_point, Vector(0.0, 0.0, -1.0));
            Array uvt = rayIntersectTriangle(dPSRec.uv_mesh0, dPSRec.uv_mesh1, dPSRec.uv_mesh2, query_ray); // find barycentric coord
            emitter_point = dPSRec.v0 + (dPSRec.v1 - dPSRec.v0) * uvt[0] + (dPSRec.v2 - dPSRec.v0) * uvt[1]; // interpolate position
        }

        Vector dir = edge_point - emitter_point;
        Float dist = dir.norm();
        dir /= dist;
        // value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, dir);
        
        Ray ray(edge_point, dir);

        // Intersection its;
        // scene.rayIntersect(ray, true, its);
        Shape* shape_cam = scene.getShape(indices[0]);
        // scene.rayIntersect(ray, true, its);
        // shape_cam->rayIntersect(its_Cam.indices[1], ray, holder);
        Vector3i idx_cam = shape_cam->getIndices(indices[1]);
        Vector v0_cam = detach(shape_cam->getVertex(idx_cam[0]));
        Vector v1_cam = detach(shape_cam->getVertex(idx_cam[1]));
        Vector v2_cam = detach(shape_cam->getVertex(idx_cam[2]));

        Array uvt = rayIntersectTriangle(v0_cam, v1_cam, v2_cam, ray);
        Vector p = edge_point + uvt[2] * dir;
        Vector geo_n = detach(shape_cam->getFaceNormal(indices[1]));

        Float G = std::abs(dPSRec.n.dot(dir)) / (dist * dist);

        Spectrum value = scene.emitter_list[dPSRec.emitter_id]->eval(light_norm, dir);
        value = value * G / dPSRec.pdf; // compensate for sampling on emitter surface
        Float J = dlD_dlB(emitter_point,
                          edge_point, (ePSRec.v0 - ePSRec.v1).normalized(),
                          p, geo_n) *
                  dA_dw(edge_point, emitter_point, light_norm);

        // // Float sensorVal = scene.camera.eval(pixel_idx[0], pixel_idx[1], its_Cam.p);
        
        // // Spectrum bsdf_val = its_Cam.evalBSDF(-dir);
        // int bsdf_id = its.ptr_shape->bsdf_id;
        Float baseValue = J * geometric(p, geo_n, emitter_point, light_norm);
        contrib = baseValue * value / ePSRec.pdf; // eRec.pdf, bsdf_val
    }
    

    /**
     * Derivative wrapper around getEdgeIntersection() using Enzyme's
     * __enzyme_autodiff.
     *
     * rnd, scene and contrib are marked enzyme_dup and paired with the shadows
     * d_rnd, d_scene and d_contrib; ePSRec, dPSRec and indices are marked
     * enzyme_const.
     * NOTE(review): presumably d_contrib is the incoming adjoint and d_rnd /
     * d_scene receive the gradients — confirm against the Enzyme convention.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_getEdgeIntersection(const Vector &rnd, Vector &d_rnd, const Scene &scene, Scene &d_scene,
                    const EdgePrimarySampleRecord &ePSRec, const EmitterPrimarySampleRecord &dPSRec, 
                    const Array2i &indices,
                    Spectrum &d_contrib) {
        // Primal output slot required by the call; its value is not used.
        [[maybe_unused]] Spectrum contrib;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of getEdgeIntersection().
        __enzyme_autodiff((void *)getEdgeIntersection,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &ePSRec,
                            enzyme_const, &dPSRec,
                            enzyme_const, &indices,
                            enzyme_dup, &contrib, &d_contrib);
#endif
    }

    void evalIntersectionRD(const Scene &scene, const RoughDielectricBSDF &rd, const Vector &rnd, const Array2i next_indices, const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rd.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    /**
     * Derivative wrapper around evalIntersectionRD() using Enzyme's
     * __enzyme_autodiff.
     *
     * scene, rnd, cur and the output intersection are marked enzyme_dup and
     * paired with d_scene, d_rnd, d_cur and d_its; rd and next_indices are
     * marked enzyme_const.
     * NOTE(review): presumably d_its carries the incoming adjoint of the next
     * vertex and d_scene / d_rnd / d_cur receive gradients — confirm against
     * the Enzyme convention.
     *
     * The shadow scene is passed by address (&d_scene), matching the other
     * wrappers in this file; passing the object itself would hand Enzyme a
     * copy (or a mistyped argument) instead of the caller's gradient storage.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_evalIntersectionRD(const Scene &scene, Scene &d_scene,
                            const RoughDielectricBSDF &rd, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Primal output slot required by the call; its value is not used.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of evalIntersectionRD().
        __enzyme_autodiff((void *)evalIntersectionRD,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rd,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    void evalIntersectionRC(const Scene &scene, const RoughConductorBSDF &rc, const Vector &rnd, const Array2i next_indices, const Intersection &cur, Intersection &next) {
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        next.value = rc.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }


    /**
     * Derivative wrapper around evalIntersectionRC() using Enzyme's
     * __enzyme_autodiff.
     *
     * scene, rnd, cur and the output intersection are marked enzyme_dup and
     * paired with d_scene, d_rnd, d_cur and d_its; rc and next_indices are
     * marked enzyme_const.
     * NOTE(review): presumably d_its carries the incoming adjoint of the next
     * vertex and d_scene / d_rnd / d_cur receive gradients — confirm against
     * the Enzyme convention.
     *
     * The shadow scene is passed by address (&d_scene), matching the other
     * wrappers in this file; passing the object itself would hand Enzyme a
     * copy (or a mistyped argument) instead of the caller's gradient storage.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_evalIntersectionRC(const Scene &scene, Scene &d_scene,
                            const RoughConductorBSDF &rc, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Primal output slot required by the call; its value is not used.
        [[maybe_unused]] Intersection val;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of evalIntersectionRC().
        __enzyme_autodiff((void *)evalIntersectionRC,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rc,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &val, &d_its);
#endif
    }

    void evalIntersectionDiffuse(const Scene &scene, const DiffuseBSDF &diffuse, const Vector &rnd, const Array2i next_indices, const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        next.value = diffuse.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = detach(Vector2(uvt[0], uvt[1]));
        return;
    }

    /**
     * Derivative wrapper around evalIntersectionDiffuse() using Enzyme's
     * __enzyme_autodiff.
     *
     * scene, rnd, cur and the output intersection are marked enzyme_dup and
     * paired with d_scene, d_rnd, d_cur and d_its; diffuse and next_indices
     * are marked enzyme_const.
     * NOTE(review): presumably d_its carries the incoming adjoint of the next
     * vertex and d_scene / d_rnd / d_cur receive gradients — confirm against
     * the Enzyme convention.
     *
     * The shadow scene is passed by address (&d_scene), matching the other
     * wrappers in this file; passing the object itself would hand Enzyme a
     * copy (or a mistyped argument) instead of the caller's gradient storage.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_evalIntersectionDiffuse(const Scene &scene, Scene &d_scene,
                            const DiffuseBSDF &diffuse, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Primal output slot required by the call; its value is not used.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of evalIntersectionDiffuse().
        __enzyme_autodiff((void *)evalIntersectionDiffuse,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &diffuse,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    void connectCamera(const Intersection &its, const Scene &scene, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        // if (its.pixel_idx[0] < 0 || its.pixel_idx[0] > scene.camera.getNumPixels()){
        //     contrib = Spectrum::Zero();
        //     return;
        // }
        Vector cam_dir = (scene.camera.cpos - its.p).normalized();
        int bsdf_id = its.ptr_shape->bsdf_id;
        Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        // auto bsdf_val = its.evalBSDF(its.toLocal(cam_dir),
        //                              EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    void connectCamera_echo(const Intersection &its, const Scene &scene, const Spectrum &weight,
                        Spectrum &contrib) 
    {
        // if (its.pixel_idx[0] < 0 || its.pixel_idx[0] > scene.camera.getNumPixels()){
        //     contrib = Spectrum::Zero();
        //     return;
        // }
        Vector cam_dir = (scene.camera.cpos - its.p).normalized();
        int bsdf_id = its.ptr_shape->bsdf_id;
        Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(its, its.toLocal(cam_dir), EBSDFMode::EImportanceWithCorrection);
        PSDR_INFO("bsdf_val: {}, {}, {}", bsdf_val[0], bsdf_val[1], bsdf_val[2]);
        PSDR_INFO("cam_dir: {}, {}, {}", its.toLocal(cam_dir)[0], its.toLocal(cam_dir)[1], its.toLocal(cam_dir)[2]);
        PSDR_INFO("its.wi: {}, {}, {}", its.wi[0], its.wi[1], its.wi[2]);
        // auto bsdf_val = its.evalBSDF(its.toLocal(cam_dir),
        //                              EBSDFMode::EImportanceWithCorrection);
        contrib = weight * bsdf_val;
    }

    /**
     * Derivative wrapper around connectCamera() using Enzyme's
     * __enzyme_autodiff.
     *
     * All four arguments of connectCamera() (its, scene, weight, contrib) are
     * marked enzyme_dup and paired with d_its, d_scene, d_weight, d_contrib.
     * NOTE(review): presumably d_contrib is the incoming adjoint and the other
     * shadows receive gradients — confirm against the Enzyme convention.
     *
     * When ENZYME or ENZYME_BOUNDARY_DIRECT is not defined this function does
     * nothing.
     */
    void d_connectCamera(const Intersection &its, Intersection &d_its, const Scene &scene, Scene &d_scene, const Spectrum &weight, Spectrum &d_weight,
                        Spectrum &d_contrib) 
    {
        // Primal output slot required by the call; its value is not used.
        [[maybe_unused]] Spectrum contrib;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        // Argument order must mirror the parameter list of connectCamera().
        __enzyme_autodiff((void *)connectCamera,
                            enzyme_dup, &its, &d_its,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_dup, &weight, &d_weight,
                            enzyme_dup, &contrib, &d_contrib);
#endif
    }

    void evalPath(const Scene &scene, const LightPathPSS &path, Float &value)
    {
        value = 0.0;
        Spectrum throughput = path.baseValue;

        if (path.vertices.size() < 1)
            return;
        // value += edgeContrib;

        for (int i = 0; i < path.vertices.size(); i++)
        {
            throughput *= path.vertices[i].value;
            Spectrum camera_i = Spectrum::Zero();
            if (i == path.vertices.size() - 1){
                connectCamera(path.vertices[i], scene, throughput, camera_i);
                value += camera_i[0] + camera_i[1] + camera_i[2];
            }
        }
    }

    void d_evalPath(const Scene &scene, Scene &d_scene, LightPathPSSAD &pathAD) {
        // Float d_value = 1.0;
        // Float value = 0.0;
        // __enzyme_autodiff((void *)evalPath,
        //                     enzyme_dup, &scene, &d_scene,
        //                     enzyme_dup, &pathAD.val, &pathAD.der,
        //                     enzyme_dup, &value, &d_value);

        Float value = 0.0;
        Spectrum throughput = pathAD.val.baseValue;

        if (pathAD.val.vertices.size() < 1)
            return;
        // value += edgeContrib;

        for (int i = 0; i < pathAD.val.vertices.size(); i++)
        {
            throughput *= pathAD.val.vertices[i].value;
        }

        Float d_value = 1.0;
        Spectrum d_throughput = Spectrum::Zero();
        if (pathAD.val.vertices[pathAD.val.vertices.size() - 1].pixel_idx[0] >= 0 
            && pathAD.val.vertices[pathAD.val.vertices.size() - 1].pixel_idx[0] < scene.camera.getNumPixels()){
            Spectrum d_camera = Spectrum(d_value);
            d_connectCamera(pathAD.val.vertices[pathAD.val.vertices.size() - 1], pathAD.der.vertices[pathAD.val.vertices.size() - 1], 
                            scene, d_scene, throughput, d_throughput, d_camera);
        }
        for (int i = pathAD.val.vertices.size() - 1; i >= 0; i--) {
            // Spectrum d_camera_i = Spectrum(d_value);
            throughput /= pathAD.val.vertices[i].value;
            pathAD.der.vertices[i].value += d_throughput * throughput;
            d_throughput *= pathAD.val.vertices[i].value;
        }
        pathAD.der.baseValue += d_throughput;
    }

    void getPath(const Scene &scene, LightPathPSS &path) {
        Vector rnd_0 = MALA::pss_get(path.pss_state, 0);
        Vector d_rnd_0(0.0, 0.0, 0.0);
        // PSDR_INFO("path.vertices[0].indices: {}, {}", path.vertices[0].indices[0], path.vertices[0].indices[1]);
        getEdgeIntersectionDir(rnd_0, scene, path.ePSRec, path.dPSRec, path.vertices[0].indices, path.baseValue, path.vertices[0]);
        for (int i = 1; i < path.vertices.size(); i++) {
            // Vector rnd = path.pss_state.get(i);
            Vector rnd = MALA::pss_get(path.pss_state, i);
            char bsdf_name[100];
            bsdf_name[0] = 0;
            path.vertices[i-1].ptr_bsdf->className(bsdf_name);
            if (strcmp(bsdf_name, "RoughConductorBSDF") == 0) {
                const RoughConductorBSDF* rc = dynamic_cast<const RoughConductorBSDF*>(path.vertices[i-1].ptr_bsdf);
                // PSDR_INFO("RoughConductorBSDF");
                evalIntersectionRC(scene, *rc, rnd, path.vertices[i].indices, 
                                    path.vertices[i - 1], path.vertices[i]);
                PSDR_INFO("path.vertices[i]: {}, {}, {}", path.vertices[i].value[0], path.vertices[i].value[1], path.vertices[i].value[2]);
                // PSDR_INFO("rc d_rnd: {}, {}, {}", d_rnd[0], d_rnd[1], d_rnd[2]);
            } else if (strcmp(bsdf_name, "RoughDielectricBSDF") == 0) {
                const RoughDielectricBSDF* rd = dynamic_cast<const RoughDielectricBSDF*>(path.vertices[i-1].ptr_bsdf);
                evalIntersectionRD(scene, *rd, rnd, path.vertices[i].indices, 
                                    path.vertices[i - 1], path.vertices[i]);
            } else if (strcmp(bsdf_name, "DiffuseBSDF") == 0) {
                // PSDR_INFO("bsdf name: {}", bsdf_name);
                const DiffuseBSDF* df = dynamic_cast<const DiffuseBSDF*>(path.vertices[i-1].ptr_bsdf);
                // DiffuseBSDF* d_df = dynamic_cast<DiffuseBSDF*>(df->clone());
                evalIntersectionDiffuse(scene, *df, rnd, path.vertices[i].indices, 
                                    path.vertices[i - 1], path.vertices[i]);
            } else {
                PSDR_INFO("bsdf name: {}", bsdf_name);
                assert(false);
            }
            // PSDR_INFO("d_rnd: {}, {}, {}", d_rnd[0], d_rnd[1], d_rnd[2]);
        }
        // vec_set(pathAD.der.pss_state, 0, d_rnd_0);
    }

    void d_getPath(const Scene &scene, Scene &d_scene, LightPathPSSAD &pathAD) {
        // PSDR_INFO("path size: {}", pathAD.val.vertices.size());
        for (int i = pathAD.val.vertices.size() - 1; i >= 1; i--) {
            // Vector rnd = pathAD.val.pss_state.get(i);
            Vector rnd = MALA::pss_get(pathAD.val.pss_state, i);
            Vector d_rnd(0.0, 0.0, 0.0);
            // PSDR_INFO("d_value {}: {}, {}, {}", i, pathAD.der.vertices[i].value[0], pathAD.der.vertices[i].value[1], pathAD.der.vertices[i].value[2]);
            // PSDR_INFO("d_p {}: {}, {}, {}", i, pathAD.der.vertices[i].p[0], pathAD.der.vertices[i].p[1], pathAD.der.vertices[i].p[2]);
            char bsdf_name[100];
            bsdf_name[0] = 0;
            pathAD.val.vertices[i-1].ptr_bsdf->className(bsdf_name);
            if (strcmp(bsdf_name, "RoughConductorBSDF") == 0) {
                const RoughConductorBSDF* rc = dynamic_cast<const RoughConductorBSDF*>(pathAD.val.vertices[i-1].ptr_bsdf);
                // PSDR_INFO("RoughConductorBSDF");
                d_evalIntersectionRC(scene, d_scene, *rc, rnd, d_rnd, pathAD.val.vertices[i].indices, 
                                    pathAD.val.vertices[i - 1], pathAD.der.vertices[i - 1], pathAD.der.vertices[i]);
                // PSDR_INFO("rc d_rnd: {}, {}, {}", d_rnd[0], d_rnd[1], d_rnd[2]);
            } else if (strcmp(bsdf_name, "RoughDielectricBSDF") == 0) {
                const RoughDielectricBSDF* rd = dynamic_cast<const RoughDielectricBSDF*>(pathAD.val.vertices[i-1].ptr_bsdf);
                d_evalIntersectionRD(scene, d_scene, *rd, rnd, d_rnd, pathAD.val.vertices[i].indices, 
                                    pathAD.val.vertices[i - 1], pathAD.der.vertices[i - 1], pathAD.der.vertices[i]);
            } else if (strcmp(bsdf_name, "DiffuseBSDF") == 0) {
                const DiffuseBSDF* df = dynamic_cast<const DiffuseBSDF*>(pathAD.val.vertices[i-1].ptr_bsdf);
                d_evalIntersectionDiffuse(scene, d_scene, *df, rnd, d_rnd, pathAD.val.vertices[i].indices, 
                                    pathAD.val.vertices[i - 1], pathAD.der.vertices[i - 1], pathAD.der.vertices[i]);
            } else {
                // PSDR_INFO("bsdf name: {}", bsdf_name);
                assert(false);
            }
            // pathAD.der.pss_state.set(i, d_rnd);
            MALA::pss_set(pathAD.der.pss_state, i, d_rnd);
            // PSDR_INFO("d_rnd {}: {}, {}, {}", i, d_rnd[0], d_rnd[1], d_rnd[2]);
        }
        // Vector rnd_0 = pathAD.val.pss_state.get(0);
        Vector rnd_0 = MALA::pss_get(pathAD.val.pss_state, 0);
        Vector d_rnd_0(0.0, 0.0, 0.0);
        d_getEdgeIntersectionDir(rnd_0, d_rnd_0, scene, d_scene, pathAD.val.ePSRec, pathAD.val.dPSRec, pathAD.val.vertices[0].indices, pathAD.der.baseValue, pathAD.der.vertices[0]);
        // pathAD.der.pss_state.set(0, d_rnd_0);
        MALA::pss_set(pathAD.der.pss_state, 0, d_rnd_0);
    }

    Spectrum eval(const Scene &scene,
                    const ArrayXd &rnd, int max_bounces,
                    const DiscreteDistribution &edge_dist,
                    const std::vector<Vector2i> &edge_indices,
                    LightPathPSS *path, bool echo)
    {
        if (path != nullptr){
            path->vertices.clear();
            path->pss_state = rnd;
            path->discrete_dim.resize(rnd.size());
            path->discrete_dim.setZero();
        }
        assert(rnd.size() == 3 * max_bounces);
        Spectrum contrib(0.0);
        /* Sample a point on the boundary */
        // path.rnd_edge_emitter = d_rnd;
        // Vector rnd_curr = d_rnd.get_u(0);
        Vector rnd_curr = MALA::pss_get(rnd, 0);
        BoundarySamplingRecord eRec;
        // scene.sampleEdgePoint(rnd[0],
        //                       edge_dist, edge_indices,
        //                       eRec, (path == nullptr) ? nullptr : &(path->ePSRec));
        sampleEdgeRay(scene, rnd_curr, edge_dist, edge_indices, eRec, (path == nullptr) ? nullptr : &(path->ePSRec));
        if (eRec.shape_id == -1)
        {
            // PSDR_WARN(eRec.shape_id == -1);
            return contrib;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, its;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, its))
            return contrib;
        if (!itsS.ptr_shape->isEmitter()) {
            // if (!its.ptr_shape->isEmitter())
                return contrib;
            // else {
            //     std::swap(itsS, its);
            //     edgeRay = edgeRay.flipped();
            //     eRec.pdf *= 2.0;
            //     if (path != nullptr) {
            //         path->ePSRec.ray_flipped = true;
            //         path->ePSRec.pdf = eRec.pdf;
            //     }
            // }
        }
        /* Sample point on emitters */
        // DirectSamplingRecord dRec(eRec.ref);
        Spectrum value = itsS.Le(-edgeRay.dir);

        if (path != nullptr) {
            path->dPSRec.emitter_id = itsS.ptr_shape->light_id;
            // path->dPSRec.emitter = itsS.ptr_emitter;
            path->dPSRec.triangle_id = itsS.indices[1];
            Vector3i ind = itsS.ptr_shape->getIndices(itsS.indices[1]);
            path->dPSRec.v0 = itsS.ptr_shape->getVertex(ind[0]);
            path->dPSRec.v1 = itsS.ptr_shape->getVertex(ind[1]);
            path->dPSRec.v2 = itsS.ptr_shape->getVertex(ind[2]);
            path->dPSRec.n = itsS.ptr_shape->getFaceNormal(itsS.indices[1]);
            path->type[0] = 'e';
        }
        // Vector2 rnd_light(rnd[1], rnd[2]);
        // Spectrum value = scene.sampleEmitterDirect(rnd_light, dRec, (path == nullptr) ? nullptr : &(path->dPSRec));
        // if (echo) {
        //     PSDR_INFO("value_fwd: {}, {}, {}", value[0], value[1], value[2]);
        //     PSDR_INFO("dRec.p: {}, {}, {}", dRec.p[0], dRec.p[1], dRec.p[2]);
        //     PSDR_INFO("dRec.barycentric: {}, {}", dRec.barycentric[0], dRec.barycentric[1]);
        // }
        if (value.isZero(Epsilon))
            return contrib;
        const Vector xB = eRec.ref,
                     &xS = itsS.p;
        // Ray ray(xB, (xB - xS).normalized());
        // Intersection its;
        // if (!scene.rayIntersect(ray, true, its))
        //     return contrib;

        // sanity check
        // make sure the ray is tangent to the surface
        // if (edge.f0 >= 0 && edge.f1 >= 0)
        // {
        //     Vector n0 = shape->getGeoNormal(edge.f0),
        //            n1 = shape->getGeoNormal(edge.f1);
        //     Float dotn0 = ray.dir.dot(n0),
        //           dotn1 = ray.dir.dot(n1);
        //     if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
        //         return contrib;
        // }
        // NOTE prevent intersection with a backface
        Float gnDotD = its.geoFrame.n.dot(edgeRay.dir);
        Float snDotD = its.shFrame.n.dot(edgeRay.dir);
        bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                       (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        if (!success)
            return contrib;
        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = edgeRay.dir;
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];

        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector &xD = its.p;
        Vector n = (v0 - v1).cross(-edgeRay.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, its.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, its.geoFrame.n, xS, itsS.geoFrame.n);
        if (echo){
            PSDR_INFO("baseValue: {}", baseValue);
        }
        its.value = Spectrum(1.0);
        if (path != nullptr)
            path->baseValue = value * baseValue / eRec.pdf;
        if (std::abs(baseValue) < Epsilon)
            return contrib;

        if (path != nullptr){
            path->bound.min = edge_dist.m_cdf[path->ePSRec.edge_id];
            path->bound.max = edge_dist.m_cdf[path->ePSRec.edge_id + 1];
            path->bound.edge_idx = path->ePSRec.edge_id;
            path->bound.shape_idx = path->ePSRec.shape_id;
            path->bound.emitter_point = itsS.p;
            path->bound.emitter_dir = edgeRay.dir;
            path->bound.mode = 1;
            assert(path->ePSRec.shape_id >= 0 && path->ePSRec.shape_id < scene.shape_list.size());
            assert(path->bound.shape_idx >= 0 && path->bound.shape_idx < scene.shape_list.size());
            assert(path->bound.edge_idx >= 0 && path->bound.edge_idx < scene.shape_list[path->bound.shape_idx]->edges.size());
        }
        // assert(baseValue > -Epsilon);

        /* Sample detector path */
        Spectrum throughput(1.0f);
        Ray ray_sensor;
        for (int i = 0; i < max_bounces; i++)
        {
            if (echo){
                PSDR_INFO("iter: {}", i);
                PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);
                PSDR_INFO("pixel_idx {}: {}, {}", i, its.pixel_idx[0], its.pixel_idx[1]);
            }

            if (path != nullptr){
                path->vertices.push_back(its);
            }
            // contrib += contrib_i;
            if (i == max_bounces - 1){ // connect to camera
                // Spectrum contrib_i(0.0);
                CameraDirectSamplingRecord cRec;
                Vector cam_dir = (scene.camera.cpos - its.p).normalized();
                if (!scene.camera.sampleDirect(its.p, cRec))
                    its.pixel_idx = Array2i(-1, -1);
                if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
                    its.pixel_idx = Array2i(-1, -1);
                auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, 0.5);
                its.pixel_idx = pixel_idx;
                if (sensor_val < Epsilon)
                    its.pixel_idx = Array2i(-1, -1);
                if (pixel_idx >= 0 && pixel_idx < scene.camera.getNumPixels()){
                    connectCamera(its, scene, value * baseValue / eRec.pdf, contrib);
                }
                break;
            }
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            // rnd_curr = d_rnd.get_u(i + 1);
            rnd_curr = MALA::pss_get(rnd, i + 1);
            if (echo){
                PSDR_INFO("rnd_curr {}: {}, {}, {}", i, rnd_curr[0], rnd_curr[1], rnd_curr[2]);
            }
            Spectrum bsdf_weight = its.sampleBSDF(rnd_curr, wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            if (path != nullptr){
                path->discrete_dim[(i + 1) * 3] = 0.0;
                path->discrete_dim[(i + 1) * 3 + 1] = 0.0;
                path->discrete_dim[(i + 1) * 3 + 2] = 1.0;
                if (its.bsdf_type == 0) {
                    path->type[i + 1] = 'd';
                } else if (its.bsdf_type == 1) {
                    path->type[i + 1] = 'r';
                } else if (its.bsdf_type == 2) {
                    path->type[i + 1] = 't';
                } else if (its.bsdf_type == -1) {
                    path->type[i + 1] = 'n';
                
                }
            }
            its.value = bsdf_weight;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0){
                // PSDR_INFO("Invalid intersection, bounce: {}", i);
                break;
            }
            value *= bsdf_weight;
            ray_sensor = Ray(its.p, wo);
            bool hit = scene.rayIntersect(ray_sensor, true, its);
            if (!hit || !its.isValid()){
                // PSDR_INFO("Invalid intersection, bounce: {}", i);
                assert(contrib.isZero());
                break;
            }
        }
        return contrib;
    }

    void d_eval(const Scene &scene, Scene &d_scene,
                  LightPathPSSAD &pathAD
                  ) {
        LightPathPSS &path_der = pathAD.der;
        pathAD.der.clear(); // don't accumulate any exisiting value
        pathAD.der.resize(pathAD.val);
        pathAD.der.setZero();
        getPath(scene, pathAD.val);
        Float value = 0.0;
        evalPath(scene, pathAD.val, value);
        // PSDR_INFO("value in d_: {}", value);
        d_evalPath(scene, d_scene, pathAD);
        d_getPath(scene, d_scene, pathAD);
    }
NAMESPACE_END(algorithm1_MALA_direct)