#include "boundary.h"

#include <algorithm>
#include <cmath>

#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/nanoflann.hpp>
#include <render/photon_map.h>
#include <core/logger.h>
#include <unsupported/Eigen/CXX11/Tensor>
#include <emitter/area.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughconductor.h>
#include <bsdf/roughdielectric.h>

using MALA::PSS_State;

namespace {
    // Map a sample `rnd` (expected in [0, 1)) to an integer in {1, 2, ..., max}.
    // The zero-based index is clamped to max - 1 so rnd == 1 cannot overshoot.
    int sampleInt(const Float &rnd, const int &max) {
        const int zero_based = static_cast<int>(rnd * max);
        const int last = max - 1;
        return (last < zero_based ? last : zero_based) + 1;
    }

    void sampleEdgeRay(const Scene &scene, const Array3 rnd,
                       const DiscreteDistribution &edge_dist,
                       const std::vector<Vector2i> &edge_indices,
                       EdgeRaySamplingRecord &eRec)
    {
        /* Sample a point on the boundary */
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        // assert(eRec.shape_id == 0);
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);
        assert(edge.mode != -1);

        /* Sample edge ray */
        if (edge.f1 < 0) // Case 1: boundary edge
        {
            eRec.dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            eRec.pdf /= 4. * M_PI;
        }
        else // Case 2: non-boundary edge
        {
            Vector n0 = shape->getGeoNormal(edge.f0);
            Vector n1 = shape->getGeoNormal(edge.f1);
            Float pdf1;
            eRec.dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
            eRec.pdf *= pdf1;

            Float dotn0 = eRec.dir.dot(n0), dotn1 = eRec.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5f)
            {
                std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
                PSDR_INFO("f0: {}, f1: {}", edge.f0, edge.f1);
                PSDR_INFO("n0: {}, {}, {}", n0[0], n0[1], n0[2]);
                PSDR_INFO("n1: {}, {}, {}", n1[0], n1[1], n1[2]);
                PSDR_INFO("dir: {}, {}, {}", eRec.dir[0], eRec.dir[1], eRec.dir[2]);
                eRec.shape_id = -1;
            }
        }
    }

    /*
     * Map a point X on the segment v0 -> v1 (edge `edge_idx` of edge_dist) to the primary
     * sample u0: the segment parameter t is projected onto [cdf[edge_idx], cdf[edge_idx + 1]].
     * Degenerate segments (v0 == v1) divide by zero.
     */
    Float Xtou0(const Vector &X,
                       const DiscreteDistribution &edge_dist, int edge_idx,
                       const Vector &v0, const Vector &v1) {
        const Vector edge_vec = v1 - v0;
        const Float cdf_lo = edge_dist.m_cdf[edge_idx];
        const Float cdf_hi = edge_dist.m_cdf[edge_idx + 1];
        const Float t = (X - v0).dot(edge_vec) / edge_vec.squaredNorm();
        return cdf_lo + t * (cdf_hi - cdf_lo);
    }

    // Map an edge-ray direction W back to its square sample (u1, u2), given the adjacent
    // face normals n0 / n1 (the inverse of squareToEdgeRayDirection).
    Vector2 Wtou1u2(const Vector &W,
                       const Vector &n0, const Vector &n1) {
        const Vector2 u1u2 = edgeRayDirectionToSquare(W, n0, n1);
        return u1u2;
    }

    /*
     * Push a perturbation d_rnd of the edge-selection sample u0 (= rnd) forward to the
     * corresponding perturbation dX of the edge point on the segment v0 -> v1.
     *
     * rnd is only used to locate the edge index i in edge_dist. Within edge i, u0 maps
     * linearly onto the segment over [cdf[i], cdf[i + 1]], so dX = (v1 - v0) * d_rnd / width.
     *
     * @return dX (a perturbation of the point only; no direction term is produced).
     */
    Vector du0todX(const Float rnd, const Float d_rnd,
                       const DiscreteDistribution &edge_dist,
                       const Vector &v0, const Vector &v1) {
        // sampleReuse overwrites its sample and pdf arguments; only the returned index is needed.
        Float scratch_rnd = rnd;
        Float scratch_pdf = 1.f;
        const int i = edge_dist.sampleReuse(scratch_rnd, scratch_pdf);

        const Float d_t = d_rnd / (edge_dist.m_cdf[i + 1] - edge_dist.m_cdf[i]);
        return (v1 - v0) * d_t;
    }

    /*
     * Pull a perturbation dX of the edge point back to a perturbation of the edge-selection
     * sample u0 (= rnd). dX is projected onto the segment v0 -> v1 to get dt, which is then
     * scaled by the CDF width of the edge that rnd selects in edge_dist.
     */
    Float dXtodu0(const Float rnd, const Vector &dX,
                       const DiscreteDistribution &edge_dist,
                       const Vector &v0, const Vector &v1) {
        Float sample_copy = rnd;
        Float unused_pdf = 1.f;
        const int edge_i = edge_dist.sampleReuse(sample_copy, unused_pdf);

        const Vector seg = v1 - v0;
        const Float d_t = seg.dot(dX) / seg.squaredNorm();
        const Float cdf_width = edge_dist.m_cdf[edge_i + 1] - edge_dist.m_cdf[edge_i];
        return d_t * cdf_width;
    }

    // Push a perturbation d_rnd of the direction square sample rnd = (u1, u2) forward to the
    // perturbation dW of the edge-ray direction (via dsquareToDEdgeRayDirection).
    // Only the direction is affected; no dX is produced here.
    Vector du1u2todW(const Vector2 rnd, const Vector2 d_rnd,
                       const Vector &n0, const Vector &n1) {
        return dsquareToDEdgeRayDirection(rnd, d_rnd, n0, n1);
    }

    void jacobianWtou1u2(const Vector2 &rnd, const Vector &n0, const Vector &n1, Vector &dWtodu1, Vector &dWtodu2) {
        dWtodu1 = du1u2todW(rnd, Vector2(1, 0), n0, n1);
        dWtodu2 = du1u2todW(rnd, Vector2(0, 1), n0, n1);
    }

    void jacobianXtou0(const Float &u0, const DiscreteDistribution &edge_dist, int edge_idx,
                       const Vector &v0, const Vector &v1, Vector &dXtodu0) {
        dXtodu0 = du0todX(u0, 1, edge_dist, v0, v1);
    }

    // Pull a direction perturbation dW back to a perturbation of the square sample
    // rnd = (u1, u2). The direction at rnd is re-evaluated first because the
    // inverse derivative needs it.
    Vector2 dWtodu1u2(const Vector2 rnd, const Vector &dW,
                       const Vector &n0, const Vector &n1) {
        Float dir_pdf = 1.0;
        Vector W = squareToEdgeRayDirection(rnd, n0, n1, dir_pdf);
        return dedgeRayDirectionToDSquare(rnd, W, dW, n0, n1);
    }

    // File-local DebugInfo instance. Internal linkage already comes from the enclosing
    // anonymous namespace, so `static` is redundant. It is not referenced in the functions
    // shown here; presumably debugging code elsewhere in this file uses it (TODO confirm).
    static DebugInfo debugInfo;

    bool test_Edge_Silhouette(const Scene &scene, int edge_idx, int shape_idx, const Vector &ray_dir) {
        const Shape *shape = scene.shape_list[shape_idx];
        const Edge &edge = shape->edges[edge_idx];

        Vector n0 = shape->getGeoNormal(edge.f0);
        Vector n1;
        if (edge.f1 >= 0) {
            n1 = shape->getGeoNormal(edge.f1);
        } else {
            assert(false);
        }
        Float dotn0 = ray_dir.dot(n0), dotn1 = ray_dir.dot(n1);
        // PSDR_INFO("n0.dot(ray_dir): {}", n0.dot(ray_dir));
        // PSDR_INFO("n1.dot(ray_dir): {}", n1.dot(ray_dir));
        // if (edge.mode == -1)
        //     return false;
        if (edge.mode == 0 || (math::signum(dotn0) * math::signum(dotn1) < -0.5f && edge.mode != -1)){
            return true;
        }
        return false;
    }

    /*
     * Check that a proposed mutation from primary sample rnd0 to rnd1 keeps every edge
     * of `mutation_path` a silhouette edge for both sampled edge-ray directions.
     *
     * Both samples are drawn with sampleEdgeRay. If either sample fails (shape_id == -1),
     * the mutation is rejected: its direction may be unset, and shape_list cannot be
     * indexed with -1. Edges are looked up on the shape of the first sample.
     *
     * @return true iff both samples succeed and every edge passes test_Edge_Silhouette.
     */
    bool test_Mutation_Validity(const Scene &scene, const Vector &rnd0, const Vector &rnd1,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices, const std::vector<int> &mutation_path) {
        BoundarySamplingRecord eRec0, eRec1;
        sampleEdgeRay(scene, rnd0, edge_dist, edge_indices, eRec0);
        sampleEdgeRay(scene, rnd1, edge_dist, edge_indices, eRec1);

        if (eRec0.shape_id == -1 || eRec1.shape_id == -1) {
            return false;
        }

        const Vector dir0 = eRec0.dir;
        const Vector dir1 = eRec1.dir;
        for (int edge_idx : mutation_path) {
            if (!test_Edge_Silhouette(scene, edge_idx, eRec0.shape_id, dir0) ||
                !test_Edge_Silhouette(scene, edge_idx, eRec0.shape_id, dir1)) {
                return false;
            }
        }
        return true;
    }

    bool edgeJump(const Vector &mutation_start, const Vector &mutation_dir, 
                   const Scene &scene, algorithm1_MALA_indirect::EdgeBound &bound, 
                   const std::vector<Vector2i> &edge_indices,
                   const std::vector<std::vector<int>> &edge_indices_inv, 
                   const DiscreteDistribution &edge_dist, 
                   RndSampler *sampler, Vector &final_dir_oppo,
                   std::vector<std::pair<Vector, int>> &mutation_path, bool across_edge) {
        final_dir_oppo = -mutation_dir;
        EdgeRaySamplingRecord eRec0, eRec1;
        sampleEdgeRay(scene, mutation_start, edge_dist, edge_indices, eRec0);
        // PSDR_INFO("mutation_start: {}, {}, {}", mutation_start[0], mutation_start[1], mutation_start[2]);
        // PSDR_INFO("mutation_start point: {}, {}, {}", eRec0.ref[0], eRec0.ref[1], eRec0.ref[2]);
        // PSDR_INFO("mutation_start dir: {}, {}, {}", eRec0.dir[0], eRec0.dir[1], eRec0.dir[2]);
        Vector mutation_dest = mutation_start + mutation_dir;
        // std::vector<int> mutation_path;
        Float distance = abs(mutation_dir[0]);

        Shape* shape = scene.shape_list[bound.shape_idx];
        int curr_edge_idx = bound.edge_idx;
        int curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
        Edge edge = shape->edges[curr_edge_idx];

        if (edge.f1 >= 0) {
            PSDR_INFO("In edgeJump: edge.f1 >= 0");
            assert(false);
        }
        Float curr_edgelen;
        int transition_vertex_idx;
        bool v0_based;
        if (mutation_start[0] + mutation_dir[0] < bound.min) { // walk the first step depending on the direction to walk
            curr_edgelen = abs(mutation_start[0] - bound.min);
            transition_vertex_idx = edge.v0;
            v0_based = true;
        } else {
            curr_edgelen = abs(mutation_start[0] - bound.max);
            transition_vertex_idx = edge.v1;
            v0_based = false;
        }
        mutation_path.push_back(std::make_pair(mutation_start, 0));
        int counter = 0;
        // PSDR_INFO("bound.dir: {}, {}, {}", bound.dir[0], bound.dir[1], bound.dir[2]);
        // PSDR_INFO("curr_edge_idx: {}", curr_edge_idx);
        Vector ray_dir = bound.dir;

        if (distance < curr_edgelen) { // no need for mutation across edge
            mutation_path.push_back(std::make_pair(mutation_dest, 0));
            sampleEdgeRay(scene, mutation_dest, edge_dist, edge_indices, eRec1);
            // PSDR_INFO("mutation_dest: {}, {}, {}", mutation_dest[0], mutation_dest[1], mutation_dest[2]);
            // PSDR_INFO("mutation_end point: {}, {}, {}", eRec1.ref[0], eRec1.ref[1], eRec1.ref[2]);
            // PSDR_INFO("mutation_end dir: {}, {}, {}", eRec1.dir[0], eRec1.dir[1], eRec1.dir[2]);
            return true;
        }
        if (!across_edge) {
            return false;
        }
        while (distance > curr_edgelen){ // mutate to the appropriate edge
            // PSDR_INFO("distance: {}, curr_edgelen: {}", distance, curr_edgelen);
            // PSDR_INFO("curr_edge_idx: {}, transition_vertex_idx: {}", curr_edge_idx, transition_vertex_idx);
            distance -= curr_edgelen;
            counter++;
            // work at the transition vertex
            std::vector<int> valid_edges;
            Vector transition_vertex = shape->vertices[transition_vertex_idx];
            // test_Edge_Silhouette(scene, curr_edge_idx, bound.shape_idx, ray_dir);
            for (int i = 0; i < shape->adjacentEdges[transition_vertex_idx].size(); i++){
                int edge_idx = shape->adjacentEdges[transition_vertex_idx][i];
                // Edge edge = shape->edges[edge_idx];
                if (edge_idx == curr_edge_idx){
                    continue;
                }
                // Vector n0 = shape->getFaceNormal(edge.f0);
                // Vector n1 = shape->getFaceNormal(edge.f1);
                // if (edge.mode == 0 || (n0.dot(bound.ray_dir) * n1.dot(bound.ray_dir) < 0 && edge.mode != -1)){
                //     valid_edges.push_back(edge_idx);
                // }
                // PSDR_INFO("edge silhouette: {}", edge_idx);
                // PSDR_INFO("ray_dir: {}, {}, {}", ray_dir[0], ray_dir[1], ray_dir[2]);
                // Vector vertex0 = shape->vertices[shape->edges[edge_idx].v0];
                // PSDR_INFO("endpoint0: {}, {}, {}", vertex0[0], vertex0[1], vertex0[2]);
                // Vector vertex1 = shape->vertices[shape->edges[edge_idx].v1];
                // PSDR_INFO("endpoint1: {}, {}, {}", vertex1[0], vertex1[1], vertex1[2]);
                // if (test_Edge_Silhouette(scene, edge_idx, bound.shape_idx, ray_dir)){
                valid_edges.push_back(edge_idx);
                // }
            }
            // assert(valid_edges.size() == 1);
            if (valid_edges.size() == 0){
                // reject
                // PSDR_INFO("valid_edges.size() == 0");
                return false; // remain current state
            }
            Float rnd = sampler->next1D();
            int picked = floor(rnd * valid_edges.size()); // randomly pick an edge at the transition vertex
            // sample reuse
            // rnd = rnd * valid_edges.size() - picked;

            // move to the next transition vertex
            curr_edge_idx = valid_edges[picked];
            curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
            // PSDR_INFO("picked {}: {}", mutation_path.size(), curr_edge_idx);
            // mutation_path.push_back(curr_edge_idx);
            if (transition_vertex_idx == shape->edges[curr_edge_idx].v0){
                transition_vertex_idx = shape->edges[curr_edge_idx].v1;
                v0_based = true;
            }
            else {
                transition_vertex_idx = shape->edges[curr_edge_idx].v0;
                v0_based = false;
            }
            curr_edgelen = edge_dist[curr_edge_idx_dist];
            if (v0_based) { // walk from v0 to v1
                mutation_dest[0] = edge_dist.m_cdf[curr_edge_idx_dist];
            }
            else {
                mutation_dest[0] = edge_dist.m_cdf[curr_edge_idx_dist + 1];
            }
            mutation_path.push_back(std::make_pair(mutation_dest, 0));
            // curr_edgelen = (shape->vertices[shape->edges[curr_edge_idx].v0] - shape->vertices[shape->edges[curr_edge_idx].v1]).norm();
        }
        Float remaining_dist_pss = (distance);
        if (v0_based) { // walk from v0 to v1
            mutation_dest[0] = edge_dist.m_cdf[curr_edge_idx_dist] + remaining_dist_pss;
            final_dir_oppo = Vector(-abs(mutation_dir[0]), -mutation_dir[1], -mutation_dir[2]);
        }
        else {
            mutation_dest[0] = edge_dist.m_cdf[curr_edge_idx_dist + 1] - remaining_dist_pss;
            final_dir_oppo = Vector(abs(mutation_dir[0]), -mutation_dir[1], -mutation_dir[2]);
        }
        mutation_path.push_back(std::make_pair(mutation_dest, 0));

        // sampleEdgeRay(scene, mutation_dest, edge_dist, edge_indices, eRec1);
        // PSDR_INFO("mutation_dest: {}, {}, {}", mutation_dest[0], mutation_dest[1], mutation_dest[2]);
        // PSDR_INFO("mutation_end point: {}, {}, {}", eRec1.ref[0], eRec1.ref[1], eRec1.ref[2]);
        // PSDR_INFO("mutation_end dir: {}, {}, {}", eRec1.dir[0], eRec1.dir[1], eRec1.dir[2]);

        if (!(mutation_dest[0] >= edge_dist.m_cdf[curr_edge_idx_dist] && mutation_dest[0] <= edge_dist.m_cdf[curr_edge_idx_dist + 1])) {
            // for (int i = 0; i < mutation_path.size(); i++) {
                // int path_idx = edge_indices_inv[bound.shape_idx][mutation_path[i]];
                // PSDR_INFO("mutation_path[{}]: {}", i, path_idx);
                // PSDR_INFO("     edge_dist.m_cdf[{}]: {}", path_idx, edge_dist.m_cdf[path_idx]);
                // PSDR_INFO("     edge_dist.m_cdf[{}]: {}", path_idx + 1, edge_dist.m_cdf[path_idx + 1]);
            // }
            // PSDR_INFO("proposal_state.u[0]: {}", proposal_state.u[0]);
            // PSDR_INFO("current_state.u[0]: {}", current_state.u[0]);
            PSDR_INFO("distance: {}", distance);
            PSDR_INFO("curr_edge_idx_dist: {}", curr_edge_idx_dist);
            PSDR_INFO("curr_edge_idx: {}", curr_edge_idx);
            PSDR_INFO("curr edge_dist.m_cdf[{}]: {}", curr_edge_idx_dist, edge_dist.m_cdf[curr_edge_idx_dist]);
            PSDR_INFO("curr edge_dist.m_cdf[{}]: {}", curr_edge_idx_dist + 1, edge_dist.m_cdf[curr_edge_idx_dist + 1]);
            PSDR_INFO("bound.min: {}, bound.max: {}", bound.min, bound.max);
            PSDR_INFO("edge_dist[curr_edge_idx_dist]: {}", edge_dist[curr_edge_idx_dist]);
            PSDR_INFO("curr_edgelen: {}", curr_edgelen);
            assert(false);
        }
        // if (!test_Mutation_Validity(scene, MALA::pss_get(mutation_start, 0), MALA::pss_get(mutation_dest, 0), edge_dist, edge_indices, mutation_path)) {
        //     // PSDR_INFO("mutation path: {}", mutation_path.size());
        //     return -1.0; // remain current state
        // }
        return true;
    }

    bool intersectBlock(const Vector &ori, const Vector &dir, const Vector &block_min, const Vector &block_max, 
                        Vector &isect, int &isect_face, Float &tmin) {
        isect_face = -1;
        // test if ori in in block
        const Float eps = 1e-4;
        if (!(ori[0] > block_min[0] - eps && ori[0] < block_max[0] + eps && 
            ori[1] > block_min[1] - eps && ori[1] < block_max[1] + eps && 
            ori[2] > block_min[2] - eps && ori[2] < block_max[2] + eps)) {
            PSDR_INFO("ori: {}, {}, {}", ori[0], ori[1], ori[2]);
            PSDR_INFO("block_min: {}, {}, {}", block_min[0], block_min[1], block_min[2]);
            PSDR_INFO("block_max: {}, {}, {}", block_max[0], block_max[1], block_max[2]);
            // assert(false);
            return false;
        }
        Vector t;
        for (int i = 0; i < 3; i++) {
            if (dir[i] > 0) {
                t[i] = (block_max[i] - ori[i]) / dir[i];
            } else if (dir[i] < 0) {
                t[i] = (block_min[i] - ori[i]) / dir[i];
            } else {
                t[i] = 10.0;
            }
        }
        tmin = std::min(std::min(t[0], t[1]), t[2]);
        isect_face = tmin == t[0] ? 0 : (tmin == t[1] ? 1 : 2);
        if (tmin < 0) {
            PSDR_INFO("tmin: {}", tmin);
            PSDR_INFO("ori: {}, {}, {}", ori[0], ori[1], ori[2]);
            PSDR_INFO("dir: {}, {}, {}", dir[0], dir[1], dir[2]);
            PSDR_INFO("block_min: {}, {}, {}", block_min[0], block_min[1], block_min[2]);
            PSDR_INFO("block_max: {}, {}, {}", block_max[0], block_max[1], block_max[2]);
            
            // assert(false);
            return false;
        }
        isect = ori + tmin * dir;
        return true;
    }

    bool blockJump(Vector mutation_start, Vector mutation_dir, 
                    const Scene &scene, algorithm1_MALA_indirect::EdgeBound &bound, 
                    const std::vector<Vector2i> &edge_indices,
                    const std::vector<std::vector<int>> &edge_indices_inv, 
                    const DiscreteDistribution &edge_dist, 
                    const std::vector<std::vector<Vector3i>> &face_to_edge_indices,
                    RndSampler *sampler, Vector &final_dir_oppo,
                    std::vector<std::pair<Vector, int>> &mutation_path, bool across_edge) {
        Vector mutation_dest = mutation_start + mutation_dir;
        mutation_path.clear();
        Float distance = mutation_dir.norm();
        Float total_dist = distance;
        Vector mutation_dir_normalized = mutation_dir.normalized();
        Shape* shape = scene.shape_list[bound.shape_idx];
        Vector bound_min = Vector(bound.min, 0, 0);
        Vector bound_max = Vector(bound.max, 1, 1);
        if (mutation_start[2] > 0.5) {
            bound_min[2] = 0.5;
        } else {
            bound_max[2] = 0.5;
        }
        Vector isect = mutation_dest;
        Vector current_pos = mutation_start;
        int transition_dim;
        Float curr_dist;
        if (!intersectBlock(current_pos, mutation_dir_normalized, bound_min, bound_max, isect, transition_dim, curr_dist)) {
            return false;
        }
        int curr_edge_idx = bound.edge_idx;
        int curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
        mutation_path.push_back(std::make_pair(mutation_start, -1));
        int counter = 0;
        // PSDR_INFO("current_pos: {}, {}, {}", current_pos[0], current_pos[1], current_pos[2]);
        while (distance > curr_dist){ // mutate to the appropriate edge
            // PSDR_INFO("     distance: {}, curr_dist: {}", distance, curr_dist);
            // PSDR_INFO("     transition_dim: {}", transition_dim);
            if (!across_edge) {
                return false;
            }
            distance -= curr_dist;
            counter++;
            Edge edge = shape->edges[curr_edge_idx];
            Vector n0 = shape->getGeoNormal(edge.f0);
            if (edge.f1 < 0) {
                PSDR_INFO("In blockJump: edge.f1 < 0");
                assert(false);
            }
            Vector n1 = shape->getGeoNormal(edge.f1);
            Float pdf;

            Vector dX = du0todX(isect[0], mutation_dir_normalized[0], edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]);
            Vector ray_dir = squareToEdgeRayDirection(Vector2(isect[1], isect[2]), n0, n1, pdf);
            Vector dW = du1u2todW(Vector2(isect[1], isect[2]), Vector2(mutation_dir_normalized[1], mutation_dir_normalized[2]), n0, n1);
            // work on the transition vertex
            if (transition_dim == 0) {

                // d_squareToEdgeRayDirection(Vector2(isect[1], isect[2]), Vector2(mutation_dir_normalized[1], mutation_dir_normalized[2]), n0, n1, ray_dir, dW);
                // // FD test
                // Vector W = squareToEdgeRayDirection(Vector2(isect[1], isect[2]), n0, n1, pdf);
                // Vector dW_fd = Vector::Zero();
                // Float eps = 1e-8;
                // for (int i = 0; i < 2; i++) {
                //     Vector2 rnd = Vector2(isect[1], isect[2]);
                //     rnd[i] += eps;
                //     Vector W_eps = squareToEdgeRayDirection(rnd, n0, n1, pdf);
                //     dW_fd += (W_eps - W) / eps;
                // }
                // Vector2 d_sample = Vector2(1, 1);
                // Vector dW_test = Vector::Zero();
                // Vector dW_mine = du1u2todW(Vector2(isect[1], isect[2]), d_sample, n0, n1);
                // Vector2 du1u2 = dWtodu1u2(Vector2(isect[1], isect[2]), dW_mine, n0, n1);
                // Vector W_test = Vector::Zero();
                // d_squareToEdgeRayDirection(Vector2(isect[1], isect[2]), d_sample, n0, n1, W_test, dW_test);
                // Float a, da, b, db;
                // d_squareToab(Vector2(isect[1], isect[2]), d_sample, n0, n1, a, da, b, db);
                // // PSDR_INFO("da, db: {}, {}", da, db);
                // EdgeSamplingRecord eRec0, eRec1;
                // eps = 1e-13;
                // scene.sampleEdgePoint(isect[0],
                //               edge_dist, edge_indices,
                //               eRec0);
                // scene.sampleEdgePoint(isect[0] - eps,
                //               edge_dist, edge_indices,
                //               eRec1);
                // Vector dX_fd = -(eRec1.ref - eRec0.ref) / (eps);
                // Vector dX_test;
                // dX_test = du0todX(isect[0], d_sample[0], edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]);
                // PSDR_INFO("dX_fd: {}, {}, {}", dX_fd[0], dX_fd[1], dX_fd[2]);
                // PSDR_INFO("dX_mine: {}, {}, {}", dX_test[0], dX_test[1], dX_test[2]);
                // PSDR_INFO("W: {}, {}, {}", W[0], W[1], W[2]);
                // PSDR_INFO("ray_dir: {}, {}, {}", ray_dir[0], ray_dir[1], ray_dir[2]);
                // PSDR_INFO("dW_fd: {}, {}, {}", dW_fd[0], dW_fd[1], dW_fd[2]);
                // PSDR_INFO("dW_mine: {}, {}, {}", dW_mine[0], dW_mine[1], dW_mine[2]);
                // PSDR_INFO("dW_enzyme: {}, {}, {}", dW_test[0], dW_test[1], dW_test[2]);
                // PSDR_INFO("du1u2: {}, {}", du1u2[0], du1u2[1]);
                
                // assert(false);

                int transition_vertex_idx = shape->edges[curr_edge_idx].v0;
                if (mutation_dir_normalized[0] > 0) {
                    transition_vertex_idx = shape->edges[curr_edge_idx].v1;
                }
                std::vector<int> valid_edges;
                Vector transition_vertex = shape->vertices[transition_vertex_idx];

                for (int i = 0; i < shape->adjacentEdges[transition_vertex_idx].size(); i++){
                    int edge_idx = shape->adjacentEdges[transition_vertex_idx][i];
                    if (edge_idx == curr_edge_idx){
                        continue;
                    }

                    if (test_Edge_Silhouette(scene, edge_idx, bound.shape_idx, ray_dir)){
                        valid_edges.push_back(edge_idx);
                    }
                }
                // assert(valid_edges.size() == 1);
                if (valid_edges.size() == 0){
                    // reject
                    // assert(false); // silhouette dissapeared halfway: is this possible?
                    return false; // so it is.
                }
                Float rnd = sampler->next1D();
                int picked = floor(rnd * valid_edges.size()); // randomly pick an edge at the transition vertex
                // sample reuse
                // rnd = rnd * valid_edges.size() - picked;

                // move to the next transition vertex
                curr_edge_idx = valid_edges[picked];
                curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
                // PSDR_INFO("picked {}: {}", mutation_path.size(), curr_edge_idx);
                
                edge = shape->edges[curr_edge_idx];
                n0 = shape->getGeoNormal(edge.f0);
                n1 = shape->getGeoNormal(edge.f1);
                Vector2 new_raypss = edgeRayDirectionToSquare(ray_dir, n0, n1);

                // Vector2 new_du1u2 = dWtodu1u2(new_raypss, dW, n0, n1);
                Float new_du0 = dXtodu0(isect[0], dX, edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]);

                current_pos[1] = new_raypss[0];
                current_pos[2] = new_raypss[1];

                if (transition_vertex_idx == shape->edges[curr_edge_idx].v0){ // starting at v0, next transition vtx will be v1
                    transition_vertex_idx = shape->edges[curr_edge_idx].v1;
                    current_pos[0] = edge_dist.m_cdf[curr_edge_idx_dist];
                    mutation_dir_normalized[0] = abs(mutation_dir_normalized[0]); // going positive to v1
                }
                else {
                    transition_vertex_idx = shape->edges[curr_edge_idx].v0;
                    current_pos[0] = edge_dist.m_cdf[curr_edge_idx_dist + 1];
                    mutation_dir_normalized[0] = -abs(mutation_dir_normalized[0]); // going negative to v0
                }
                if (dW.norm() > 1e-8){
                    Vector2 u1u2 = Vector2(current_pos[1], current_pos[2]);
                    Vector dW_norm = dW.normalized();
                    Vector dW00 = du1u2todW(u1u2, Vector2(mutation_dir_normalized[1], mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW00 = 1.0 - dW00.normalized().dot(dW_norm);

                    Vector dW01 = du1u2todW(u1u2, Vector2(mutation_dir_normalized[1], -mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW01 = 1.0 - dW01.normalized().dot(dW_norm);

                    Vector dW10 = du1u2todW(u1u2, Vector2(-mutation_dir_normalized[1], mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW10 = 1.0 - dW10.normalized().dot(dW_norm);

                    Vector dW11 = du1u2todW(u1u2, Vector2(-mutation_dir_normalized[1], -mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW11 = 1.0 - dW11.normalized().dot(dW_norm);
                    
                    Float min_dist = std::min(std::min(std::min(dist_dW00, dist_dW01), dist_dW10), dist_dW11);
                    if (min_dist == dist_dW00) {
                        mutation_dir_normalized[1] = mutation_dir_normalized[1];
                        mutation_dir_normalized[2] = mutation_dir_normalized[2];
                    } else if (min_dist == dist_dW01) {
                        mutation_dir_normalized[1] = mutation_dir_normalized[1];
                        mutation_dir_normalized[2] = -mutation_dir_normalized[2];
                    } else if (min_dist == dist_dW10) {
                        mutation_dir_normalized[1] = -mutation_dir_normalized[1];
                        mutation_dir_normalized[2] = mutation_dir_normalized[2];
                    } else if (min_dist == dist_dW11){
                        mutation_dir_normalized[1] = -mutation_dir_normalized[1];
                        mutation_dir_normalized[2] = -mutation_dir_normalized[2];
                    } else {
                        PSDR_INFO("dW: {}, {}, {}", dW[0], dW[1], dW[2]);
                        PSDR_INFO("min_dist: {}", min_dist);
                        PSDR_INFO("dist_dW00: {}", dist_dW00);
                        PSDR_INFO("dist_dW01: {}", dist_dW01);
                        PSDR_INFO("dist_dW10: {}", dist_dW10);
                        PSDR_INFO("dist_dW11: {}", dist_dW11);
                        // assert(false);
                        return false;
                    }
                }

            } else if (transition_dim == 1) {
                EdgeSamplingRecord _;
                scene.sampleEdgePoint(isect[0],
                              edge_dist, edge_indices,
                              _);
                Vector ray_org = _.ref;
                ray_dir = squareToEdgeRayDirection(Vector2(isect[1], isect[2]), n0, n1, pdf, false);

                Vector2 sancheck = edgeRayDirectionToSquare(ray_dir, n0, n1);
                if (abs(sancheck[0] - isect[1]) > 1e-8 || abs(sancheck[1] - isect[2]) > 1e-8) {
                    PSDR_INFO("sancheck: {}, {}", sancheck[0], sancheck[1]);
                    PSDR_INFO("isect: {}, {}", isect[1], isect[2]);
                    PSDR_INFO("ray_dir: {}, {}, {}", ray_dir[0], ray_dir[1], ray_dir[2]);
                    assert(false);
                }

                ray_dir = ray_dir.normalized();
                Float raydotf0 = ray_dir.dot(n0);
                Float raydotf1 = ray_dir.dot(n1);
                int face_id = 0;
                if (abs(raydotf0) < 1e-8) { // jump to an edge on f0
                    face_id = edge.f0;
                } else if (abs(raydotf1) < 1e-8) { // jump to an edge on f1
                    face_id = edge.f1;
                } else {
                    PSDR_INFO("numerical error: raydotf0: {}, raydotf1: {}", raydotf0, raydotf1);
                    assert(false);
                }
                Vector in_tri_intersection;
                Float u, t;
                bool found = false;
                int old_edge_idx = curr_edge_idx;
                for (int i = 0; i < face_to_edge_indices[bound.shape_idx][face_id].size(); i++) { // find the edge and point that the ray intersects
                    int edge_idx = face_to_edge_indices[bound.shape_idx][face_id][i];
                    if (edge_idx == curr_edge_idx || edge_idx < 0) { // is current edge or concave edge
                        continue;
                    }
                    Edge edge_i = shape->edges[edge_idx];
                    if (rayLineIntersection(ray_org, ray_dir, shape->vertices[edge_i.v0], shape->vertices[edge_i.v1], 
                                            in_tri_intersection, u, t)) {
                        Edge target_edge = shape->edges[edge_idx];
                        if (target_edge.mode == -1) { // skip concave edges
                            continue;
                        }
                        curr_edge_idx = edge_idx;
                        curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];

                        found = true;
                        break;
                    }
                }
                if (!found) { // all edges concave, invalid walk
                    return false;
                }

                edge = shape->edges[curr_edge_idx];
                Float v0_pss = edge_dist.m_cdf[curr_edge_idx_dist];
                Float v1_pss = edge_dist.m_cdf[curr_edge_idx_dist + 1];
                n0 = shape->getGeoNormal(edge.f0);
                n1 = shape->getGeoNormal(edge.f1);
                Vector2 new_raypss = edgeRayDirectionToSquare(ray_dir, n0, n1);
                if (new_raypss[0] < -1e-4 || new_raypss[0] > 1 + 1e-4) {
                    Edge new_edge = shape->edges[curr_edge_idx];
                    Edge old_edge = shape->edges[old_edge_idx];
                    PSDR_INFO("new_edge.v0, v1: {}, {}", new_edge.v0, new_edge.v1);
                    PSDR_INFO("old_edge.v0, v1: {}, {}", old_edge.v0, old_edge.v1);
                    PSDR_INFO("new_raypss: {}, {}", new_raypss[0], new_raypss[1]);
                    PSDR_INFO("current_pos: {}, {}, {}", current_pos[0], current_pos[1], current_pos[2]);
                    PSDR_INFO("mutation_dir: {}, {}, {}", mutation_dir[0], mutation_dir[1], mutation_dir[2]);
                    PSDR_INFO("ray_dir: {}, {}, {}", ray_dir[0], ray_dir[1], ray_dir[2]);
                    PSDR_INFO("n0: {}, {}, {}", n0[0], n0[1], n0[2]);
                    PSDR_INFO("n1: {}, {}, {}", n1[0], n1[1], n1[2]);
                    // assert(false);
                }
                Vector new_pos;
                new_pos[0] = v0_pss + (v1_pss - v0_pss) * u;
                scene.sampleEdgePoint(new_pos[0],
                              edge_dist, edge_indices,
                              _);
                if ((_.ref - in_tri_intersection).norm() > 1e-5) { // san check
                    PSDR_INFO("_.ref: {}, {}, {}", _.ref[0], _.ref[1], _.ref[2]);
                    PSDR_INFO("in_tri_intersection: {}, {}, {}", in_tri_intersection[0], in_tri_intersection[1], in_tri_intersection[2]);
                    PSDR_INFO("new_pos: {}, {}, {}", new_pos[0], new_pos[1], new_pos[2]);
                    // assert(false);
                    return false;
                }
                new_pos[1] = new_raypss[0];
                new_pos[2] = new_raypss[1];
                
                Vector2 new_du1u2 = dWtodu1u2(new_raypss, dW, n0, n1);
                Float new_du0 = dXtodu0(isect[0], dX, edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]);

                {
                    if (new_pos[1] < 1e-8) {
                        mutation_dir_normalized[1] = abs(mutation_dir_normalized[1]);
                    } else if (new_pos[1] > 1-1e-8) {
                        mutation_dir_normalized[1] = -abs(mutation_dir_normalized[1]);
                    }
                }
                {
                    Vector dX_norm = dX.normalized();
                    Vector dX0 = du0todX(new_pos[0], mutation_dir_normalized[0], edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]).normalized();
                    Float dist_dX0 = 1.0 - dX0.dot(dX_norm);
                    Vector dX1 = du0todX(new_pos[0], -mutation_dir_normalized[0], edge_dist, shape->vertices[edge.v0], shape->vertices[edge.v1]).normalized();
                    Float dist_dX1 = 1.0 - dX1.dot(dX_norm);
                    if (dist_dX0 < dist_dX1) {
                        mutation_dir_normalized[0] = mutation_dir_normalized[0];
                    } else {
                        mutation_dir_normalized[0] = -mutation_dir_normalized[0];
                    }
                }
                {
                    Vector W_placeholder = Vector::Zero();
                    Vector dW_norm = dW.normalized();
                    Vector dW0 = du1u2todW(Vector2(new_pos[1], new_pos[2]), Vector2(mutation_dir_normalized[1], mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW0 = 1.0 - dW0.normalized().dot(dW_norm);

                    Vector dW1 = du1u2todW(Vector2(new_pos[1], new_pos[2]), Vector2(mutation_dir_normalized[1], -mutation_dir_normalized[2]), n0, n1);
                    Float dist_dW1 = 1.0 - dW1.normalized().dot(dW_norm);
                    if (dist_dW0 < dist_dW1) {
                        mutation_dir_normalized[2] = mutation_dir_normalized[2];
                    } else {
                        mutation_dir_normalized[2] = -mutation_dir_normalized[2];
                    }
                }
                current_pos = new_pos;
            } else if (transition_dim == 2) {
                Vector new_pos = isect;
                Vector curr_dir = squareToEdgeRayDirection(Vector2(isect[1], isect[2]), n0, n1, pdf);
                // mutation_dir_normalized[2] = 1.0 - mutation_dir_normalized[2];
                if (bound_min[2] == 0.5) {
                    if (mutation_dir_normalized[2] > 0.0) {
                        new_pos[2] = 0.5 - 1e-8;
                    } else {
                        new_pos[2] = 0.0 + 1e-8;
                    }
                } else {
                    if (mutation_dir_normalized[2] > 0.0) {
                        new_pos[2] = 1.0 - 1e-8;
                    } else {
                        new_pos[2] = 0.5 + 1e-8;
                    }
                }
                mutation_dir_normalized[2] = -mutation_dir_normalized[2];
                current_pos = new_pos;
            }
            // EdgeRaySamplingRecord erRec;
            // sampleEdgeRay(scene, current_pos, edge_dist, edge_indices, erRec);
            // if (erRec.shape_id < 0) {
            //     return false;
            // }
            bound_min = Vector(edge_dist.m_cdf[curr_edge_idx_dist], 0, 0);
            bound_max = Vector(edge_dist.m_cdf[curr_edge_idx_dist + 1], 1, 1);
            if (current_pos[2] > 0.5) {
                bound_min[2] = 0.5;
            } else {
                bound_max[2] = 0.5;
            }
            if (bound_min[0] > bound_max[0]) {
                PSDR_INFO("bound_min: {}, {}, {}", bound_min[0], bound_min[1], bound_min[2]);
                PSDR_INFO("bound_max: {}, {}, {}", bound_max[0], bound_max[1], bound_max[2]);
                PSDR_INFO("current_pos: {}, {}, {}", current_pos[0], current_pos[1], current_pos[2]);
                PSDR_INFO("curr_edge_idx: {}", curr_edge_idx);
                PSDR_INFO("curr_edge_idx_dist: {}", curr_edge_idx_dist);
                PSDR_INFO("edge_dist[curr_edge_idx_dist]: {}", edge_dist[curr_edge_idx_dist]);
                PSDR_INFO("edge_dist[curr_edge_idx_dist + 1]: {}", edge_dist[curr_edge_idx_dist + 1]);
                PSDR_INFO("transition_dim: {}", transition_dim);
                assert(false);
            }
            // curr_dist = intersectBlock(current_pos, mutation_dir_normalized, bound_min, bound_max, isect, transition_dim);
            if (!intersectBlock(current_pos, mutation_dir_normalized, bound_min, bound_max, isect, transition_dim, curr_dist)) {
                return false;
            }
            mutation_path.push_back(std::make_pair(current_pos + mutation_dir_normalized * 1e-5, transition_dim));
            // if (counter > 1000) {
            //     return false;
            // }
        }
        Float remaining_dist_pss = (distance);
        Vector final_pos = current_pos + remaining_dist_pss * mutation_dir_normalized;
        // PSDR_INFO("remaining_dist_pss: {}", remaining_dist_pss);
        // PSDR_INFO("mutation_dir: {}, {}, {}", mutation_dir[0], mutation_dir[1], mutation_dir[2]);
        // PSDR_INFO("mutation_dir_normalized: {}, {}, {}", mutation_dir_normalized[0], mutation_dir_normalized[1], mutation_dir_normalized[2]);
        // PSDR_INFO("final_pos: {}, {}, {}", final_pos[0], final_pos[1], final_pos[2]);
        final_dir_oppo = -mutation_dir_normalized * total_dist;
        mutation_path.push_back(std::make_pair(final_pos, -1));
        return true;
    }

    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        const Shape *shapeD = scene.shape_list[bRec.shape_id_D];
        const auto &indD = shapeD->getIndices(bRec.tri_id_D);
        const Vector &xD_0 = shapeD->getVertex(indD[0]);
        const Vector &xD_1 = shapeD->getVertex(indD[1]);
        const Vector &xD_2 = shapeD->getVertex(indD[2]);

        res = normal_velocity(xS_0, xS_1, xS_2,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xD_0, xD_1, xD_2);
    }

    // Reverse-mode differentiate `velocity` with Enzyme: the output adjoint is
    // seeded with d_u and the shadow copy d_scene receives the scene gradients.
    // Compiles to a no-op unless both ENZYME and ENZYME_BOUNDARY_DIRECT are defined.
    // NOTE(review): eRec is forwarded by pointer to velocity(), which reads it as a
    // BoundarySamplingRecord, so the referenced object must really be one -- confirm
    // at every call site.
    void d_velocity(const Scene &scene, Scene &d_scene,
                    EdgeSamplingRecord &eRec,
                    Float d_u)
    {
        // Primal output slot of velocity(); its value is not needed (enzyme_dupnoneed).
        [[maybe_unused]] Float u;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }
    
    /*
     * Connect a surface vertex directly to the camera.
     *
     * Returns the BSDF at `its` evaluated toward the camera direction produced by
     * sampleDirect (importance mode with correction). Returns zero when:
     *   - the camera cannot be sampled from its.p,
     *   - the camera position is occluded, or
     *   - the sampled pixel's sensor response is below Epsilon.
     * One sampler dimension is consumed only when the first two checks pass. The
     * sensor response is used only as a threshold; it is not folded into the result.
     */
    inline Spectrum connectCamera(const Intersection &its,
                                  const Scene &scene,
                                  RndSampler &sampler)
    {
        CameraDirectSamplingRecord camRec;
        const bool reachable = scene.camera.sampleDirect(its.p, camRec) &&
                               scene.isVisible(its.p, true, scene.camera.cpos, true);
        if (reachable) {
            const auto [pixel, response] =
                scene.camera.sampleDirectPixel(camRec, sampler.next1D());
            if (!(response < Epsilon))
                return its.evalBSDF(its.toLocal(camRec.dir),
                                    EBSDFMode::EImportanceWithCorrection);
        }
        return Spectrum(0.f);
    }

    inline bool handleSurfaceInteraction(const Intersection &its,
                                         const Scene &scene, Scene &d_scene,
                                         EdgeRaySamplingRecord &eRec,
                                         RndSampler &sampler,
                                         std::vector<Spectrum> &d_image)
    {
        int shape_idx = -1;
        if (forward) shape_idx = scene.getShapeRequiresGrad();
        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(its.p, cRec)){
            // PSDR_INFO("!sampledirect");
            assert(false);
            return false;
        }
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true)){
            // PSDR_INFO("!visible");
            assert(false);
            return false;
        }
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensor_val < Epsilon){
            // PSDR_INFO("!sensorval");
            assert(false);
            return false;
        }

        // auto bsdf_val = its.evalBSDF(its.toLocal(cRec.dir),
        //                              EBSDFMode::EImportanceWithCorrection);
        // Spectrum value = weight * bsdf_val;
        Float d_u = (d_image[pixel_idx] * sensor_val).sum();

        // if (d_u < 1e-10)
        //     return false;
        Float param = 0;
        if (forward) {
            d_scene.shape_list[shape_idx]->param = 0.;
            param = d_scene.shape_list[shape_idx]->param;
        }
        d_velocity(scene, d_scene, eRec, d_u);
        if (forward) {
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
        }
        return true;
    }

    // One light-carrying subpath: the primary-sample-space coordinates that
    // reproduce it (one 3D sample per sampled direction), its vertices, and its
    // contribution.
    struct PathSampleRecord {
        std::vector<Vector> pss;             // 3D random numbers, in sampling order
        std::vector<Intersection> path;      // path vertices, in order
        Spectrum contrib = Spectrum::Zero(); // contribution (zero until assigned)
    };
    
    // Per-query state for __Li.
    //   sampler     - random number source to draw from
    //   max_bounces - bounce budget
    //   values      - one radiance slot per path length; index 0 is emission
    //                 seen at the first hit
    struct BoundaryRadianceQueryRecord
    {
        RndSampler *sampler;
        int max_bounces;
        std::vector<Spectrum> values;

        BoundaryRadianceQueryRecord(RndSampler *rnd_sampler, int bounce_limit)
            : sampler(rnd_sampler),
              max_bounces(bounce_limit),
              values(static_cast<std::size_t>(bounce_limit + 1), Spectrum::Zero())
        {
        }
    };
    

    /*
     * Trace a source subpath starting with `_ray` and enumerate its light-carrying
     * subpaths.
     *
     * Emission strategy per vertex:
     *   - at "DiffuseBSDF" vertices, emission is gathered by emitter sampling (NEE);
     *   - at every other vertex, it is gathered when the BSDF-sampled ray hits an emitter.
     *
     * Each such connection, plus emission seen at the first hit, is appended to
     * SPaths. It is stored with the primary-sample-space coordinates that
     * reproduce it and its value is written to rRec.values[path length].
     * For NEE the BSDF random numbers are recovered by inverting cosine-hemisphere
     * sampling.
     *
     * SPaths is cleared first. Returns the sum of all recorded contributions.
     */
    Spectrum __Li(const Scene &scene, const Ray &_ray, BoundaryRadianceQueryRecord &rRec, std::vector<PathSampleRecord> &SPaths)
    {
        SPaths.clear();
        Ray ray(_ray);
        RndSampler *sampler = rRec.sampler;
        Intersection its;
        PathSampleRecord SPath;
        Spectrum ret = Spectrum::Zero();
        scene.rayIntersect(ray, true, its);
        if (!its.isValid())
            return Spectrum::Zero();

        SPath.path.push_back(its);

        Spectrum throughput = Spectrum::Ones();
        if (its.isEmitter())
        {
            // Emission seen directly at the first hit: zero bounces, empty pss.
            ret += throughput * its.Le(-ray.dir);
            rRec.values[0] = throughput * its.Le(-ray.dir);
            PathSampleRecord SPathDirect = SPath;
            SPathDirect.contrib = rRec.values[0];
            SPaths.push_back(SPathDirect);
        }
        for (int depth = 0; depth < rRec.max_bounces && its.isValid(); depth++)
        {
            Vector wo;
            DirectSamplingRecord dRec(its);
            char bsdf_name[100];
            bsdf_name[0] = 0;
            its.getBSDF()->className(bsdf_name);
            // Diffuse vertices use emitter sampling; all others rely on BSDF hits.
            bool isDiffuse = (strcmp(bsdf_name, "DiffuseBSDF") == 0);

            if (isDiffuse) {
                Vector rnd = sampler->next3D();
                auto value = scene.sampleEmitterDirect(Vector2(rnd.x(), rnd.y()), dRec);
                wo = its.toLocal(dRec.dir.normalized());
                // Also require the light direction to lie in the upper hemisphere.
                // Cosine-hemisphere sampling never yields wo.z() < 0, so such a wo
                // cannot be reproduced from bsdf_rnd below: it would record a bogus
                // pss and trip the round-trip check. Such a connection presumably
                // carries no diffuse contribution anyway.
                if (!value.isZero(Epsilon) && its.wi.z() > Epsilon && wo.z() > 0)
                {
                    auto bsdf_val = its.evalBSDF(wo);
                    ret += throughput * value * bsdf_val;
                    rRec.values[depth + 1] = throughput * value * bsdf_val;

                    // Recover the BSDF random numbers that reproduce wo (rnd[2] is
                    // carried along) so the NEE path can be replayed from its pss.
                    PathSampleRecord SPathNEE = SPath;
                    Vector2 bsdf_rnd = invSquareToCosineHemisphere(wo);
                    Vector wo_dd;
                    Float pdf, eta;
                    its.sampleBSDF(Vector(bsdf_rnd.x(), bsdf_rnd.y(), rnd[2]), wo_dd, pdf, eta);
                    if((wo_dd - wo).norm() > 1e-6) { // round-trip sanity check
                        PSDR_INFO("wo: {}, {}, {}", wo[0], wo[1], wo[2]);
                        PSDR_INFO("wo_dd: {}, {}, {}", wo_dd[0], wo_dd[1], wo_dd[2]);
                        PSDR_INFO("bsdf_rnd: {}, {}", bsdf_rnd[0], bsdf_rnd[1]);
                        Vector wo_dd2 = squareToCosineHemisphere(Vector2(bsdf_rnd.x(), bsdf_rnd.y()));
                        PSDR_INFO("wo_dd2: {}, {}, {}", wo_dd2[0], wo_dd2[1], wo_dd2[2]);
                        PSDR_INFO("its.wi: {}, {}, {}", its.wi[0], its.wi[1], its.wi[2]);
                        assert(false);
                    }

                    SPathNEE.pss.push_back(Vector(bsdf_rnd.x(), bsdf_rnd.y(), rnd[2]));
                    Intersection its_S(dRec);
                    its_S.p = dRec.p;
                    SPathNEE.path.push_back(its_S);
                    SPathNEE.contrib = rRec.values[depth + 1];
                    SPaths.push_back(SPathNEE);
                }
            }
            // Indirect illumination: extend the subpath by BSDF sampling.
            Float bsdf_pdf, bsdf_eta;
            Vector rnd_i = sampler->next3D();
            auto bsdf_weight = its.sampleBSDF(rnd_i, wo, bsdf_pdf, bsdf_eta);
            if (bsdf_weight.isZero(Epsilon))
                break;
            wo = its.toWorld(wo);
            ray = Ray(its.p, wo);

            if (!scene.rayIntersect(ray, true, its))
                break;

            SPath.pss.push_back(rnd_i);
            SPath.path.push_back(its);
            throughput *= bsdf_weight;

            if (its.isEmitter())
            {
                Spectrum light_contrib = its.Le(-ray.dir);
                if (!light_contrib.isZero(Epsilon))
                {
                    // Diffuse vertices already accounted for emission through NEE.
                    if (!isDiffuse){
                        ret += throughput * light_contrib;
                        rRec.values[depth + 1] = throughput * light_contrib;

                        PathSampleRecord SPathBSDF = SPath;
                        SPathBSDF.contrib = rRec.values[depth + 1];
                        SPaths.push_back(SPathBSDF);
                    }
                }
            }
        }
        return ret;
    }

    // A Markov-chain seed: one sampled boundary path and its contribution.
    // All members have in-class defaults, so a default-constructed seed holds no
    // indeterminate values.
    struct SeedPath {
        Spectrum weight = Spectrum::Zero(); // path contribution
        MALA::MALAVector pss;               // [edge sample, source-side slots..., detector-side slots...], 3 entries per slot
        int cam_b = 0;                      // detector-side bounces before the camera connection
        int max_b = 0;                      // number of 3D slots in pss (pss.size() / 3)
        SeedPath() = default;
        SeedPath(const Spectrum &spec): weight(spec) {}
    };

    // Concatenate the primary-sample-space coordinates of a full boundary path.
    // Layout, one 3D slot each: [pssE, SPath.pss..., DPath.pss...].
    MALA::MALAVector compositePath(const PathSampleRecord &SPath, const PathSampleRecord &DPath, const Vector &pssE) {
        const auto n_slots = 1 + SPath.pss.size() + DPath.pss.size();
        MALA::MALAVector composed(3 * n_slots);
        int slot = 0;
        MALA::pss_set(composed, slot++, pssE);
        for (const Vector &sample : SPath.pss)
            MALA::pss_set(composed, slot++, sample);
        for (const Vector &sample : DPath.pss)
            MALA::pss_set(composed, slot++, sample);
        return composed;
    }

    /*
     * Sample one indirect boundary path and collect Markov-chain seeds from it.
     *
     * Steps:
     *   1. edge_rnd selects an edge point and edge-ray direction via sampleEdgeRay.
     *   2. The ray is traced both ways, giving the S-side hit (itsS) and D-side hit (itsD).
     *   3. The sample is rejected (zero returned) if:
     *      - the edge sample fails,
     *      - either trace misses,
     *      - the ray is not tangent to the edge's faces,
     *      - the ray hits the edge's own faces,
     *      - the hits are on the wrong side, or
     *      - the Jacobian/pdf is degenerate.
     *   4. A source subpath is traced from the boundary point toward xS with __Li.
     *   5. A detector subpath is grown from itsD by BSDF sampling, and each of its
     *      vertices is connected to the camera.
     *
     * Every (source subpath, detector prefix) pair is appended to seed_paths when
     * its pss has at most max_bounces 3D slots and its contribution is non-zero
     * (tolerance 1e-8).
     *
     * Returns the sum of the appended seed contributions.
     */
    Spectrum sampleIndirectBoundary(const Scene &scene, const Vector &edge_rnd,
                                RndSampler &sampler, const int &max_bounces,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices, std::vector<SeedPath> &seed_paths)
    {
        Vector rnd = edge_rnd;
        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, rnd, edge_dist, edge_indices, eRec);
        if (eRec.shape_id < 0)
        {
            return Spectrum(0.f);
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, itsD;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, itsD))
            return Spectrum(0.f);
        // populate the data in BoundarySamplingRecord eRec
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = itsD.indices[0];
        eRec.tri_id_D = itsD.indices[1];

        // make sure the ray is tangent to the surface
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = edgeRay.dir.dot(n0),
                  dotn1 = edgeRay.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
            {
                PSDR_ASSERT_MSG(false, "Bad edge ray sample: [ {}, {} ]", dotn0, dotn1);
                return Spectrum(0.f);
            }
        }

        /* prevent self intersection */
        const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
        if (itsS.indices == ind0 || itsS.indices == ind1 ||
            itsD.indices == ind0 || itsD.indices == ind1)
            return Spectrum(0.f);
        // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
        const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                    gn2d1 = itsD.geoFrame.n.dot(edgeRay.dir), sn2d1 = itsD.shFrame.n.dot(edgeRay.dir);
        bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
             valid2 = (itsD.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!itsD.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
        if (itsS.isEmitter())
            valid1 = true;
        if (!valid1 || !valid2)
            return Spectrum(0.f);

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector xB = v0 + (v1 - v0) * eRec.t;
        const Vector &xS = itsS.p;
        const Vector &xD = itsD.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, itsD.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n);
        // Reject degenerate samples. d_sampleIndirectBoundary asserts baseValue > 0
        // when it replays a state, and a non-positive/NaN baseValue or a zero pdf
        // would otherwise produce negative, NaN or inf seed contributions.
        if (!(baseValue > 0) || !(eRec.pdf > 0))
            return Spectrum(0.f);

        /* Sample source path */
        BoundaryRadianceQueryRecord rRec(&sampler, max_bounces);
        std::vector<PathSampleRecord> SPaths;
        __Li(scene, Ray{xB, (xS - xB).normalized()}.shifted(), rRec, SPaths);

        /* Sample detector path */
        Spectrum throughput(1.0f);
        Ray ray_sensor;
        Intersection its(itsD);
        Spectrum contrib = Spectrum(0.f);

        PathSampleRecord DPath;
        DPath.path.push_back(its);

        for (int i = 0; i < max_bounces; i++)
        {
            // Connect detector vertex i to the camera and pair it with every source
            // subpath that still fits the bounce budget.
            Spectrum cam_bsdf = connectCamera(its, scene, sampler);
            const std::size_t max_source_slots = static_cast<std::size_t>(max_bounces - 1 - i);
            for (const PathSampleRecord &SPath : SPaths) {
                if (SPath.pss.size() > max_source_slots)
                    continue;
                Spectrum contrib_ij = cam_bsdf * SPath.contrib * throughput * baseValue / eRec.pdf;
                if (contrib_ij.isZero(1e-8))
                    continue;
                SeedPath seed_path(contrib_ij);
                seed_path.pss = compositePath(SPath, DPath, rnd);
                seed_path.cam_b = i;
                seed_path.max_b = seed_path.pss.size() / 3;
                contrib += contrib_ij;
                seed_paths.push_back(seed_path);
            }
            // Extend the detector subpath by one BSDF-sampled bounce.
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Vector rnd_i = sampler.next3D();
            Spectrum bsdf_weight = its.sampleBSDF(rnd_i, wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            // Reject directions whose geometric and shading hemispheres disagree.
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                break;
            throughput *= bsdf_weight;
            ray_sensor = Ray(its.p, wo);
            scene.rayIntersect(ray_sensor, true, its);
            if (!its.isValid())
                break;
            DPath.pss.push_back(rnd_i);
            DPath.path.push_back(its);
        }
        return contrib;
    }

    /*
     * Replay the boundary path encoded in `pss` and back-propagate its image
     * derivative into d_scene.
     *
     * pss layout, one 3D slot each:
     *   - slot 0: the edge sample;
     *   - next (max_b - cam_b - 1) slots: source side (not replayed here);
     *   - last cam_b slots: detector-side BSDF samples.
     *
     * The edge ray is rebuilt and checked with the same rejection tests that
     * sampleIndirectBoundary applies. The detector subpath is then re-walked for
     * cam_b bounces and its final vertex goes to handleSurfaceInteraction.
     *
     * States that fail a rejection test are presumably never passed in. That is
     * asserted in debug builds; release builds return without touching d_scene
     * instead of using invalid data.
     */
    void d_sampleIndirectBoundary(const Scene &scene, Scene &d_scene,
                                RndSampler &sampler, const int max_b, const int cam_b,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices,
                                std::vector<Spectrum> &d_image, const MALA::MALAVector &pss)
    {
        Vector pss_stateE = MALA::pss_get(pss, 0);
        MALA::MALAVector pss_stateD;
        if (cam_b > 0){
            pss_stateD.resize(3 * cam_b);
            pss_stateD = pss.segment(3 * (max_b - cam_b), 3 * cam_b);
        }

        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, pss_stateE, edge_dist, edge_indices, eRec);
        if (eRec.shape_id < 0)
        {
            assert(false);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, itsD;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, itsD))
        {
            assert(false);
            return;
        }
        // populate the data in BoundarySamplingRecord eRec
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = itsD.indices[0];
        eRec.tri_id_D = itsD.indices[1];

        // make sure the ray is tangent to the surface
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = edgeRay.dir.dot(n0),
                  dotn1 = edgeRay.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
            {
                PSDR_ASSERT_MSG(false, "Bad edge ray sample: [ {}, {} ]", dotn0, dotn1);
                assert(false);
                return;
            }
        }

        /* prevent self intersection */
        const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
        if (itsS.indices == ind0 || itsS.indices == ind1 ||
            itsD.indices == ind0 || itsD.indices == ind1)
        {
            assert(false);
            return;
        }
        // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
        const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                    gn2d1 = itsD.geoFrame.n.dot(edgeRay.dir), sn2d1 = itsD.shFrame.n.dot(edgeRay.dir);
        bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
             valid2 = (itsD.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!itsD.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
        if (itsS.isEmitter())
            valid1 = true;
        if (!valid1 || !valid2)
        {
            assert(false);
            return;
        }

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector xB = v0 + (v1 - v0) * eRec.t;
        const Vector &xS = itsS.p;
        const Vector &xD = itsD.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, itsD.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n);
        if (!(baseValue > 0.0))
        {
            assert(false);
            return;
        }

        /* Re-walk the detector path */
        Spectrum throughput(1.0f);
        Ray ray_sensor;
        Intersection its(itsD);
        for (int i = 0; i < cam_b; i++)
        {
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Spectrum bsdf_weight = its.sampleBSDF(MALA::pss_get(pss_stateD, i), wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                return;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                return;
            throughput *= bsdf_weight;
            ray_sensor = Ray(its.p, wo);
            scene.rayIntersect(ray_sensor, true, its);
            if (!its.isValid())
                return;
        }
        assert(throughput.abs().sum() < 1e100);
        handleSurfaceInteraction(its, scene, d_scene,
                                 eRec, sampler,
                                 d_image);
    }
}

/*
 * Independence ("large step") proposal for the indirect-boundary chain.
 *
 * Keeps the current path configuration: n = cur_n slots and i = cur_i camera
 * bounces. Slot 0 (the edge sample) is drawn from grid_distrb when the grid has
 * mass and uniformly otherwise; every other slot is drawn uniformly. The result
 * is written to `proposal`.
 *
 * Returns the Metropolis-Hastings acceptance probability
 *     clamp(f(proposal) * q(current) / (f(current) * q(proposal)), 0, 1)
 * where:
 *   - f is the RGB sum of algorithm1_MALA_indirect::eval;
 *   - q is the slot-0 grid pdf (1 when the grid is unused).
 *
 * Degenerate ratios are resolved explicitly instead of yielding NaN or inf:
 *   - a proposal reported with non-positive density is rejected (0);
 *   - if the current state has no positive contribution, the proposal is accepted
 *     exactly when its own contribution is positive.
 *
 * max_bounces is currently unused.
 */
Float mutateLargeStep(const Scene &scene, RndSampler *sampler, int max_bounces, int cur_n, int cur_i, 
                    const DiscreteDistribution &edge_dist,
                    const std::vector<Vector2i> &edge_indices,
                    const Grid3D_Sampling::grid3D &grid_distrb,
                    const MALA::MALAVector &current, MALA::MALAVector &proposal, int &n, int &i) { 
    Vector proposal0 = sampler->next3D();
    n = cur_n;
    i = cur_i;
    proposal.resize(n * 3);
    Float proposal_pdf = 1.0, current_pdf = 1.0;
    if (grid_distrb.distrb.getSum() > Epsilon)
    {
        // Only slot 0 is importance sampled; the remaining slots are uniform (pdf 1).
        Vector current0 = MALA::pss_get(current, 0);
        Float current0_pdf = grid_distrb.query_pdf(current0);
        current_pdf *= current0_pdf;
        proposal0 = grid_distrb.sample(proposal0, proposal_pdf);
    }
    MALA::pss_set(proposal, 0, proposal0);
    // `k`, not `i`: the out-parameter `i` must not be shadowed.
    for (int k = 1; k < n; k++){
        MALA::pss_set(proposal, k, sampler->next3D());
    }
    Spectrum contrib_cur_spec = algorithm1_MALA_indirect::eval(scene, current, cur_n, cur_i, edge_dist, edge_indices);
    Spectrum contrib_prop_spec = algorithm1_MALA_indirect::eval(scene, proposal, n, i, edge_dist, edge_indices);
    Float contrib_cur = contrib_cur_spec[0] + contrib_cur_spec[1] + contrib_cur_spec[2];
    Float contrib_prop = contrib_prop_spec[0] + contrib_prop_spec[1] + contrib_prop_spec[2];
    // Guard the divisions below against 0/0 (NaN) and x/0 (inf) acceptance ratios.
    if (!(proposal_pdf > 0.0))
        return 0.0;
    if (!(contrib_cur > 0.0))
        return contrib_prop > 0.0 ? 1.0 : 0.0;
    Float A = contrib_prop / contrib_cur / proposal_pdf * current_pdf;
    A = clamp(A, 0.0, 1.0);
    return A; // acceptance probability
}

Float mutateSmallStep(const Scene &scene, Scene &d_scene, 
                            RndSampler *sampler, int max_bounces, int n_cam_bounces,
                            MALA::Mutation &mutation, MALA::KNNCache &cache,
                            const DiscreteDistribution &edge_dist,
                            const std::vector<Vector2i> &edge_indices,
                            const std::vector<std::vector<int>> &edge_indices_inv,
                            const std::vector<std::vector<Vector3i>> &face_to_edge_indices,
                            Float step_length, const MALAOptions &options,
                            const MALA::MALAVector &current, MALA::MALAVector &proposal, int *require_mut_dim1 = nullptr) {
    // if (reuse) {
    //     current = proposal;
    // } else {
    MALA::PSS_State current_state(max_bounces * 3);
    MALA::PSS_State proposal_state(max_bounces * 3);
    current_state.u = current;
    // current_state.u[0] = current[1];
    // current_state.u[1] = current[2];
    algorithm1_MALA_indirect::EdgeBound bound;
    algorithm1_MALA_indirect::LightPathPSS path_u(max_bounces * 3);
    Spectrum fu_spec(0.0f);
    fu_spec = algorithm1_MALA_indirect::eval(scene, current, max_bounces, n_cam_bounces, edge_dist, edge_indices, &path_u, false);
    // eval_DirectBoundary(current, scene, 1, edge_dist, edge_indices, false, fu_spec, &bound);
    bound = path_u.bound;
    int curr_edge_idx = bound.edge_idx;
    // Float edge_length = scene.shape_list[bound.shape_idx]->edges[curr_edge_idx].length;
    Shape* shape = scene.shape_list[bound.shape_idx];
    Edge edge = shape->edges[curr_edge_idx];
    Vector n0 = shape->getGeoNormal(edge.f0);
    Vector n1 = shape->getGeoNormal(edge.f1);
    Vector v0 = shape->vertices[edge.v0];
    Vector v1 = shape->vertices[edge.v1];
    // PSDR_INFO("here1");
    // PSDR_INFO("bound.shape_id: {}", bound.shape_idx);
    // PSDR_INFO("bound.edge_id: {}", bound.edge_idx);
    int edge_id_dist = edge_indices_inv[bound.shape_idx][bound.edge_idx];
    // PSDR_INFO("edge_id_dist: {}", edge_id_dist);
    bound.min = edge_dist.m_cdf[edge_id_dist];
    bound.max = edge_dist.m_cdf[edge_id_dist + 1];

    // PSDR_INFO("here2");
    if (!(current_state.u[0] > bound.min && current_state.u[0] < bound.max)) {
        EdgeSamplingRecord erec;
        EdgePrimarySampleRecord ePSRec;
        assert(edge_dist.size() == edge_indices.size());
        scene.sampleEdgePoint(current_state.u[0], edge_dist, edge_indices, erec, &ePSRec);
        algorithm1_MALA_indirect::LightPathPSS path_tmp(max_bounces * 3);
        algorithm1_MALA_indirect::eval(scene, current, max_bounces, n_cam_bounces, edge_dist, edge_indices, &path_tmp, true);
        PSDR_INFO("ePSRec.edge_dist_id: {}", ePSRec.edge_dist_id);
        PSDR_INFO("edge_indices[edge_id_dist][0]: {}", edge_indices[edge_id_dist][0]);
        PSDR_INFO("edge_indices[edge_id_dist][1]: {}", edge_indices[edge_id_dist][1]);
        PSDR_INFO("eRec.shape_id: {}", erec.shape_id);
        PSDR_INFO("eRec.edge_id: {}", erec.edge_id);
        PSDR_INFO("bound.min: {}", bound.min);
        PSDR_INFO("bound.max: {}", bound.max);
        PSDR_INFO("ePSRec.min: {}", ePSRec.min);
        PSDR_INFO("ePSRec.max: {}", ePSRec.max);
        PSDR_INFO("current_state.u[0]: {}", current_state.u[0]);
        assert(false);
    }
    current_state.f_u = fu_spec[0] + fu_spec[1] + fu_spec[2];
    // gaussian_test(current_state.u, current_state.f_u);
    if (current_state.f_u == 0){
        return 0.0;
        // assert(false);
    }


    MALA::MALAVector m_u, M_u;
    // if (cache.write || !mutation.step_readonly(cache, current_state.u, Vector(1.0, 1.0, 1.0), m_u, M_u)){
    // Vector d_u_ = d_eval_DirectBoundary(current, scene, d_scene, 1, edge_dist, edge_indices);
    MALA::MALAVector u_edge((max_bounces + 1) * 3), d_u_edge((max_bounces + 1) * 3);
    u_edge = path_u.getEdgePSSState(max_bounces, n_cam_bounces);
    if (options.use_gradient){
        algorithm1_MALA_indirect::LightPathPSSAD pathAD_u(path_u);
        algorithm1_MALA_indirect::d_eval(scene, d_scene, pathAD_u);
        MALA::MALAVector d_u(max_bounces * 3);
        d_u = pathAD_u.der.getPSSState(max_bounces, n_cam_bounces);
        current_state.g = d_u / current_state.f_u;
        d_u_edge = pathAD_u.der.getEdgePSSState(max_bounces, n_cam_bounces) / current_state.f_u;
        // current_state.g[0] = dXtodu0(current_state.u[0], MALA::pss_get(d_u_edge, 0), edge_dist, v0, v1);
        Vector2 g1g2 = dWtodu1u2(Vector2(current_state.u[1], current_state.u[2]), MALA::pss_get(d_u_edge, 1), n0, n1);
        // current_state.g[1] = g1g2[0];
        // current_state.g[2] = g1g2[1];
    } else {
        current_state.g = MALA::MALAVector::Zero(max_bounces * 3);
        d_u_edge = MALA::MALAVector::Zero((max_bounces + 1) * 3);
    }
    // PSDR_INFO("HERE1");
    mutation.step(cache, current_state.u, current_state.g, m_u, M_u, n_cam_bounces);
    // mutation.step(cache, u_edge, d_u_edge, m_u, M_u, n_cam_bounces); // use edge point + edge dir to query cache
    // MALA::MALAVector m_u_pss = MALA::MALAVector::Zero(max_bounces * 3), M_u_pss = MALA::MALAVector::Zero(max_bounces * 3);
    // m_u_pss[0] = dXtodu0(current_state.u[0], MALA::pss_get(m_u, 0), edge_dist, v0, v1);
    // Vector2 m1m2 = dWtodu1u2(Vector2(current_state.u[1], current_state.u[2]), MALA::pss_get(m_u, 1), n0, n1);
    // m_u_pss[1] = m1m2[0];
    // m_u_pss[2] = m1m2[1];
    // Vector Xtou0, Wtou1, Wtou2;
    // jacobianWtou1u2(Vector2(current_state.u[1], current_state.u[2]), n0, n1, Wtou1, Wtou2);
    // jacobianXtou0(current_state.u[0], edge_dist, edge_id_dist, v0, v1, Xtou0);
    // M_u_pss[0] = Xtou0.cwiseProduct(Xtou0).cwiseProduct(MALA::pss_get(M_u, 0)).sum();
    // M_u_pss[1] = Wtou1.cwiseProduct(Wtou1).cwiseProduct(MALA::pss_get(M_u, 1)).sum();
    // M_u_pss[2] = Wtou2.cwiseProduct(Wtou2).cwiseProduct(MALA::pss_get(M_u, 1)).sum();
    MALA::Gaussian gaussian_uv;
    MALA::MALAVector step_length_vec = MALA::MALAVector::Constant(m_u.size(), step_length);

    Vector dX_reference = Vector(1.0, 1.0, 1.0);
    Vector dW_reference = Vector(1.0, 1.0, 1.0);
    Float du0_ref = dXtodu0(current_state.u[0], dX_reference, edge_dist, v0, v1);
    Vector2 du1u2_ref = dWtodu1u2(Vector2(current_state.u[1], current_state.u[2]), dW_reference, n0, n1);
    
    // int curr_edge_idx = bound.edge_idx;
    Float edge_length = scene.shape_list[bound.shape_idx]->edges[curr_edge_idx].length;
    Float warp_norm = (abs(du0_ref) + abs(du1u2_ref[0]) + abs(du1u2_ref[1])) / 3.0;
    if (options.scale_steplength && edge.f1 >= 0) { // when this dimension for wedge
        step_length_vec[0] *= abs(du0_ref) * options.scale[0];
        step_length_vec[1] *= abs(du1u2_ref[0]) * options.scale[1];
        step_length_vec[2] *= abs(du1u2_ref[1]) * options.scale[2];
        // step_length_vec[0] *= edge_length;
    } else { // do we really need to scale here? the old implementation seems to work well under this case.
        // step_length_vec[0] *= abs(du0_ref);
    }
    // step_length_vec[0] *= edge_length;
    // PSDR_INFO("m_u_pss: {}, {}, {}", m_u_pss[0], m_u_pss[1], m_u_pss[2]);
    // PSDR_INFO("M_u_pss: {}, {}, {}", M_u_pss[0], M_u_pss[1], M_u_pss[2]);

    MALA::ComputeGaussian(step_length_vec, m_u, M_u/*, path.getDiscreteDim(max_bounces, n_cam_bounces)*/, gaussian_uv);
    // Vector w = sampler->next2D();
    // MALA::MALAVector dd = path.getDiscreteDim(max_bounces, n_cam_bounces);
    // PSDR_INFO("dd.size: {}", dd.size());
    // proposal_state.u = gaussian_uv.GenerateSample(sampler, max_bounces * 3).cwiseProduct(
    //                     1.0 - path.getDiscreteDim(max_bounces, n_cam_bounces)) + current_state.u;
    // // PSDR_INFO("HERE2");
    // PSDR_INFO("gaussian_uv.mean: {}, {}, {}", gaussian_uv.mean[0], gaussian_uv.mean[1], gaussian_uv.mean[2]);
    // PSDR_INFO("gaussian_uv.covL_d: {}, {}, {}", gaussian_uv.covL_d[0], gaussian_uv.covL_d[1], gaussian_uv.covL_d[2]);
    MALA::MALAVector v_edge_prop = gaussian_uv.GenerateSample(sampler) + current_state.u;// + u_edge;
    // PSDR_INFO("current_state.u: {}, {}, {}", current_state.u[0], current_state.u[1], current_state.u[2]);
    // PSDR_INFO("v_edge_prop: {}, {}, {}", v_edge_prop[0], v_edge_prop[1], v_edge_prop[2]);
    // PSDR_INFO("u_edge_point: {}, {}, {}", u_edge[0], u_edge[1], u_edge[2]);
    // PSDR_INFO("u_edge_dir: {}, {}, {}", u_edge[3], u_edge[4], u_edge[5]);
    // PSDR_INFO("v_edge_point: {}, {}, {}", v_edge_prop[0], v_edge_prop[1], v_edge_prop[2]);
    // PSDR_INFO("v_edge_dir: {}, {}, {}", v_edge_prop[3], v_edge_prop[4], v_edge_prop[5]);
    // PSDR_INFO("d_u_edge_point: {}, {}, {}", d_u_edge[0], d_u_edge[1], d_u_edge[2]);
    // PSDR_INFO("d_u_edge_dir: {}, {}, {}", d_u_edge[3], d_u_edge[4], d_u_edge[5]);
    MALA::MALAVector offset = v_edge_prop - current_state.u; // offset in world space for the edge    
    MALA::MALAVector neg_offset = -offset;
    // PSDR_INFO("HERE3");
    for (int i = 3; i < (max_bounces) * 3; i++){
        if (v_edge_prop[i] < 0.0 || v_edge_prop[i] > 1.0){
            return 0.0; // reject
        }
    }
    // step_length_vec[0] *= edge_dist[bound.edge_idx + 1];
    Vector mutation_dir = MALA::pss_get(offset, 0); // calculate mutation dist in primary sample space
    // mutation_dir[0] = dXtodu0(current_state.u[0], MALA::pss_get(offset, 0), edge_dist, v0, v1); // map X to u0
    // Vector2 u1u2 = dWtodu1u2(Vector2(current_state.u[1], current_state.u[2]), MALA::pss_get(offset, 1), n0, n1); // map W to u1u2
    // mutation_dir[1] = u1u2[0];
    // mutation_dir[2] = u1u2[1];
    Vector final_dir_oppo;
    // assert(false);
    // if (true){ // always try to mutate across edge
        // PSDR_INFO("block size: {}x{}x{}", bound.max - bound.min, 1, 1);
        std::vector<std::pair<Vector, int>> mutation_path; // contains start, end and every transition state
        bool success;

        if (edge.f1 < 0) { // if is topology boundary
            success = edgeJump(MALA::pss_get(current_state.u, 0), 
                        mutation_dir,
                        scene, bound, edge_indices, edge_indices_inv, edge_dist, sampler, final_dir_oppo, mutation_path, options.across_edge);
        } else { // if is wedge
            success = blockJump(MALA::pss_get(current_state.u, 0), 
                        mutation_dir,
                        scene, bound, edge_indices, edge_indices_inv, edge_dist, face_to_edge_indices, sampler, final_dir_oppo, mutation_path, options.across_edge); // mutate in PSS
        }
        if (!success){
            // PSDR_INFO("mutation_path.size(): {}", mutation_path.size());
            return 0.0; // reject
        } else {
            // PSDR_INFO("mutation_path.size(): {}", mutation_path.size());
            // PSDR_INFO("mutation_dir: {}, {}, {}", mutation_dir[0], mutation_dir[1], mutation_dir[2]);
            proposal_state.u[0] = mutation_path[mutation_path.size() - 1].first[0];
            proposal_state.u[1] = mutation_path[mutation_path.size() - 1].first[1];
            proposal_state.u[2] = mutation_path[mutation_path.size() - 1].first[2]; // new position in PSS.
            // PSDR_INFO("proposal_state.u: {}, {}, {}", proposal_state.u[0], proposal_state.u[1], proposal_state.u[2]);
            proposal_state.u.segment(3, 3 * (max_bounces - 1)) = v_edge_prop.segment(3, 3 * (max_bounces - 1));
        }

        if (require_mut_dim1 != nullptr) {
            for (int i = 0; i < mutation_path.size(); i++){
                if (mutation_path[i].second == 1){
                    *require_mut_dim1 = i;
                    break;
                }
            }
        }
        // if (u_0 < -0.5) {
        //     return 0.0; // reject
        // }
        // proposal_state.u[0] = u_0;
        // EdgeRaySamplingRecord erRec_current, erRec_proposal;
        // sampleEdgeRay(scene, MALA::pss_get(current_state.u, 0), edge_dist, edge_indices, erRec_current);
        // sampleEdgeRay(scene, MALA::pss_get(proposal_state.u, 0), edge_dist, edge_indices, erRec_proposal);
        // PSDR_INFO("mutation_dir: {}, {}, {}", mutation_dir[0], mutation_dir[1], mutation_dir[2]);
        // PSDR_INFO("erRec_current.ray_dir: {}, {}, {}", erRec_current.dir[0], erRec_current.dir[1], erRec_current.dir[2]);
        // PSDR_INFO("erRec_proposal.ray_dir: {}, {}, {}", erRec_proposal.dir[0], erRec_proposal.dir[1], erRec_proposal.dir[2]);
        // PSDR_INFO("erRec_current.ref: {}, {}, {}", erRec_current.ref[0], erRec_current.ref[1], erRec_current.ref[2]);
        // PSDR_INFO("erRec_proposal.ref: {}, {}, {}", erRec_proposal.ref[0], erRec_proposal.ref[1], erRec_proposal.ref[2]);

        // PSDR_INFO("length of mutation_path: {}", mutation_path.size());

        
    // }
    // PSDR_INFO("HERE4");
    proposal = proposal_state.u;
    Spectrum fv_spec(0.0f);
    algorithm1_MALA_indirect::LightPathPSS path_v(max_bounces * 3);
    // eval_DirectBoundary(proposal_state.u, scene, 1, edge_dist, edge_indices, false, fv_spec);
    fv_spec = algorithm1_MALA_indirect::eval(scene, proposal, max_bounces, n_cam_bounces, edge_dist, edge_indices, &path_v, false);
    Shape* shape_v = scene.shape_list[path_v.bound.shape_idx];
    int curr_edge_idx_v = path_v.bound.edge_idx;
    Edge edge_v = shape->edges[curr_edge_idx_v];
    Vector n0_v = shape_v->getGeoNormal(edge_v.f0);
    Vector n1_v = shape_v->getGeoNormal(edge_v.f1);
    Vector v0_v = shape_v->vertices[edge_v.v0];
    Vector v1_v = shape_v->vertices[edge_v.v1];
    
    proposal_state.f_u = fv_spec[0] + fv_spec[1] + fv_spec[2];
    if (proposal_state.f_u == 0.0){
        // reject
        // PSDR_INFO("proposal.fu == 0.0");
        return 0.0; // remain current state
    }
    // if (strcmp(path_u.typeS, path_v.typeS) != 0 || strcmp(path_u.typeD, path_v.typeD) != 0){
    //     // reject
    //     return 0.0; // remain current state
    // }

    MALA::MALAVector v_edge((max_bounces + 1) * 3), d_v_edge((max_bounces + 1) * 3);
    v_edge = path_v.getEdgePSSState(max_bounces, n_cam_bounces);
    // PSDR_INFO("actual v_edge_point: {}, {}, {}", v_edge[0], v_edge[1], v_edge[2]);
    // PSDR_INFO("actual v_edge_dir: {}, {}, {}", v_edge[3], v_edge[4], v_edge[5]);
    MALA::MALAVector m_v, M_v;
    if (options.use_gradient){
        algorithm1_MALA_indirect::LightPathPSSAD pathAD_v(path_v);
        algorithm1_MALA_indirect::d_eval(scene, d_scene, pathAD_v);
        MALA::MALAVector d_v(max_bounces * 3);
        d_v = pathAD_v.der.getPSSState(max_bounces, n_cam_bounces);
        proposal_state.g = d_v / proposal_state.f_u;

        d_v_edge = pathAD_v.der.getEdgePSSState(max_bounces, n_cam_bounces) / proposal_state.f_u;
        // proposal_state.g[0] = dXtodu0(proposal_state.u[0], MALA::pss_get(d_v_edge, 0), edge_dist, v0_v, v1_v);
        // Vector2 g1g2 = dWtodu1u2(Vector2(proposal_state.u[1], proposal_state.u[2]), MALA::pss_get(d_v_edge, 1), n0_v, n1_v);
        // proposal_state.g[1] = g1g2[0];
        // proposal_state.g[2] = g1g2[1];
    } else {
        proposal_state.g = MALA::MALAVector::Zero(max_bounces * 3);
        d_v_edge = MALA::MALAVector::Zero((max_bounces + 1) * 3);
    }
    // PSDR_INFO("d_v_edge_point: {}, {}, {}", d_v_edge[0], d_v_edge[1], d_v_edge[2]);
    // PSDR_INFO("d_v_edge_dir: {}, {}, {}", d_v_edge[3], d_v_edge[4], d_v_edge[5]);
    // proposal_state.g[0] = d_rnd_v[1] / proposal_state.f_u;
    // proposal_state.g[1] = d_rnd_v[2] / proposal_state.f_u;
    // mutation.step_hypo(cache, v_edge, d_v_edge, m_v, M_v, n_cam_bounces); // cache query in world space
    mutation.step_hypo(cache, proposal_state.u, proposal_state.g, m_v, M_v, n_cam_bounces); // cache query in pss space
    // MALA::MALAVector m_v_pss = MALA::MALAVector::Zero(max_bounces * 3), M_v_pss = MALA::MALAVector::Zero(max_bounces * 3);
    // m_v_pss[0] = dXtodu0(proposal_state.u[0], MALA::pss_get(m_v, 0), edge_dist, v0_v, v1_v);
    // Vector2 m1m2_v = dWtodu1u2(Vector2(proposal_state.u[1], proposal_state.u[2]), MALA::pss_get(m_v, 1), n0_v, n1_v);
    // m_v_pss[1] = m1m2_v[0];
    // m_v_pss[2] = m1m2_v[1];
    // Vector Xtov0, Wtov1, Wtov2;
    // jacobianWtou1u2(Vector2(proposal_state.u[1], proposal_state.u[2]), n0_v, n1_v, Wtov1, Wtov2);
    
    // int edge_id_dist_v = edge_indices_inv[path_v.bound.shape_idx][path_v.bound.edge_idx];
    // jacobianXtou0(proposal_state.u[0], edge_dist, edge_id_dist_v, v0, v1, Xtov0);
    // M_v_pss[0] = Xtov0.cwiseProduct(Xtov0).cwiseProduct(MALA::pss_get(M_v, 0)).sum();
    // M_v_pss[1] = Wtov1.cwiseProduct(Wtov1).cwiseProduct(MALA::pss_get(M_v, 1)).sum();
    // M_v_pss[2] = Wtov2.cwiseProduct(Wtov2).cwiseProduct(MALA::pss_get(M_v, 1)).sum();
    dX_reference = Vector(1.0, 1.0, 1.0);
    dW_reference = Vector(1.0, 1.0, 1.0);
    du0_ref = dXtodu0(proposal_state.u[0], dX_reference, edge_dist, v0_v, v1_v);
    du1u2_ref = dWtodu1u2(Vector2(proposal_state.u[1], proposal_state.u[2]), dW_reference, n0_v, n1_v);
    
    // int curr_edge_idx = bound.edge_idx;
    // Float edge_length = scene.shape_list[bound.shape_idx]->edges[curr_edge_idx].length;
    step_length_vec = MALA::MALAVector::Constant(m_u.size(), step_length);
    edge_length = scene.shape_list[bound.shape_idx]->edges[curr_edge_idx].length;
    warp_norm = (abs(du0_ref) + abs(du1u2_ref[0]) + abs(du1u2_ref[1])) / 3.0;
    if (options.scale_steplength && edge.f1 >= 0) { // when this dimension for wedge
        step_length_vec[0] *= abs(du0_ref) * options.scale[0];
        step_length_vec[1] *= abs(du1u2_ref[0]) * options.scale[1];
        step_length_vec[2] *= abs(du1u2_ref[1]) * options.scale[2];
        // step_length_vec[0] *= edge_length;
    } else { // do we really need to scale here? the old implementation seems to work well under this case.
        // step_length_vec[0] *= abs(du0_ref);
    }

    MALA::Gaussian gaussian_vu;
    MALA::ComputeGaussian(step_length_vec, m_v, M_v/*, path_v.getDiscreteDim(max_bounces, n_cam_bounces)*/, gaussian_vu);

    // MALA::MALAVector vu_offset = u_edge - v_edge; 
    // MALA::pss_set(vu_offset, 0, du0todX(proposal_state.u[0], final_dir_oppo[0], edge_dist, v0_v, v1_v));
    // MALA::pss_set(vu_offset, 1, du1u2todW(Vector2(proposal_state.u[1], proposal_state.u[2]), Vector2(final_dir_oppo[1], final_dir_oppo[2]), n0_v, n1_v));
    // MALA::MALAVector actual_offset = v_edge - u_edge; // offset in world space for the edge    
    MALA::pss_set(neg_offset, 0, final_dir_oppo);

    Float log_pdf_uv = 0;
    Float log_pdf_vu = 0;
    log_pdf_uv = gaussian_uv.GaussianLogPdf(offset/*, path.getDiscreteDim(max_bounces, n_cam_bounces)*/);
    log_pdf_vu = gaussian_vu.GaussianLogPdf(neg_offset/*, path_v.getDiscreteDim(max_bounces, n_cam_bounces)*/);
    // PSDR_INFO("offset: {}, {}, {}", offset[0], offset[1], offset[2]);
    // PSDR_INFO("final_dir_oppo: {}, {}, {}", final_dir_oppo[0], final_dir_oppo[1], final_dir_oppo[2]);
    // PSDR_INFO("f_u, f_v: {}, {}", current_state.f_u, proposal_state.f_u);
    // PSDR_INFO("log_pdf_uv, log_pdf_vu: {}, {}", log_pdf_uv, log_pdf_vu);
    Float A = exp(log_pdf_vu - log_pdf_uv) * proposal_state.f_u / current_state.f_u;
    // PSDR_INFO("A: {}", A);
    A = clamp(A, 0.0, 1.0);
    return A;
}
    void evalIntersectionRDRad(const Scene &scene, const RoughDielectricBSDF &rd, const Vector &rnd, const Array2i next_indices, 
                            const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rd.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::ERadiance);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    void evalIntersectionRDIC(const Scene &scene, const RoughDielectricBSDF &rd, const Vector &rnd, const Array2i next_indices, 
                            const Intersection &cur, Intersection &next) { // assuming that we already have something in next
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        int bsdf_id = cur.ptr_shape->bsdf_id;
        next.value = rd.sample(cur, rnd, wo_local,
                        bsdf_pdf, bsdf_eta,
                        EBSDFMode::EImportanceWithCorrection);
        wo = cur.toWorld(wo_local);
        Ray ray = Ray(cur.p, wo);
        Shape* shape = scene.getShape(next_indices[0]);
        Vector3i idx = shape->getIndices(next_indices[1]);
        Vector v0_next = detach(shape->getVertex(idx[0]));
        Vector v1_next = detach(shape->getVertex(idx[1]));
        Vector v2_next = detach(shape->getVertex(idx[2]));
        Array uvt = rayIntersectTriangle(v0_next, v1_next, v2_next, ray);
        next.p = cur.p + uvt[2] * wo;
        Vector geo_n = detach(shape->getFaceNormal(next_indices[1]));
        Vector sh_n = detach(shape->getShadingNormal(next_indices[1], Vector2(uvt[0], uvt[1])));
        next.geoFrame = Frame(geo_n);
        next.shFrame = Frame(sh_n);
        next.wi = next.toLocal(-wo);
        next.uv = Vector2(uvt[0], uvt[1]);
        return;
    }

    // Reverse-mode derivative of evalIntersectionRDRad, computed with Enzyme (__enzyme_autodiff).
    //   d_scene : shadow of `scene` (enzyme_dup).
    //   d_rnd   : shadow of `rnd` (enzyme_dup).
    //   d_cur   : shadow of `cur` (enzyme_dup).
    //   d_its   : shadow of the primal output intersection (the local `its` below); presumably
    //             seeded by the caller with the adjoint of the next intersection — TODO confirm.
    // `rd` and `next_indices` are passed as enzyme_const (not differentiated).
    // The autodiff call is compiled only when both ENZYME and ENZYME_BOUNDARY_DIRECT are defined;
    // otherwise this function does nothing and every shadow is left untouched.
    // NOTE(review): the primal takes `next_indices` by value (const Array2i) while the call below
    // passes its address; confirm this matches how that parameter is lowered for Enzyme.
    void d_evalIntersectionRDRad(const Scene &scene, Scene &d_scene,
                            const RoughDielectricBSDF &rd, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Primal output buffer for evalIntersectionRDRad; only its shadow (d_its) matters here.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRDRad,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rd,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

    // Reverse-mode derivative of evalIntersectionRDIC (importance-with-correction sampling mode),
    // computed with Enzyme (__enzyme_autodiff).
    //   d_scene : shadow of `scene` (enzyme_dup).
    //   d_rnd   : shadow of `rnd` (enzyme_dup).
    //   d_cur   : shadow of `cur` (enzyme_dup).
    //   d_its   : shadow of the primal output intersection (the local `its` below); presumably
    //             seeded by the caller with the adjoint of the next intersection — TODO confirm.
    // `rd` and `next_indices` are passed as enzyme_const (not differentiated).
    // The autodiff call is compiled only when both ENZYME and ENZYME_BOUNDARY_DIRECT are defined;
    // otherwise this function does nothing and every shadow is left untouched.
    // NOTE(review): the primal takes `next_indices` by value (const Array2i) while the call below
    // passes its address; confirm this matches how that parameter is lowered for Enzyme.
    void d_evalIntersectionRDIC(const Scene &scene, Scene &d_scene,
                            const RoughDielectricBSDF &rd, 
                            const Vector &rnd, Vector &d_rnd, 
                            const Array2i next_indices, 
                            const Intersection &cur, Intersection &d_cur,
                            Intersection &d_its) {
        // Primal output buffer for evalIntersectionRDIC; only its shadow (d_its) matters here.
        [[maybe_unused]] Intersection its;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)evalIntersectionRDIC,
                            enzyme_dup, &scene, &d_scene,
                            enzyme_const, &rd,
                            enzyme_dup, &rnd, &d_rnd,
                            enzyme_const, &next_indices,
                            enzyme_dup, &cur, &d_cur,
                            enzyme_dup, &its, &d_its);
#endif
    }

void RD_BSDF_test(const RoughDielectricBSDF &rd, Intersection &its, const Vector &geo_n, const Vector &wi, const Vector &rnd, Spectrum &value, Vector &wo) {
    // Vector wo_local;
    its.wi = wi;
    its.geoFrame = Frame(geo_n);
    Float pdf, eta;
    value = rd.sample(its, rnd, wo, pdf, eta, EBSDFMode::ERadiance);
    // value = rd.eval(its, wo);
    // value /= (pdf);
}

void RD_BSDF_test_echo(const RoughDielectricBSDF &rd, const Vector &geo_n, const Vector &wi, const Vector &rnd, Spectrum &value, Vector &wo) {
    Float pdf, eta;
    Intersection its;
    its.wi = wi;
    its.geoFrame = Frame(geo_n);
    value = rd.sample(its, rnd, wo, pdf, eta, EBSDFMode::ERadiance);
    // its.pdf = pdf;
    // PSDR_INFO("wo_local: {}, {}, {}", wo[0], wo[1], wo[2]);
    // PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);
    // value = rd.eval(its, wo);
    // PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);
    // PSDR_INFO("pdf: {}", pdf);
}

void d_RD_BSDF_test(const RoughDielectricBSDF &rd, const Intersection &its, const Vector& geo_n, 
                    const Vector &wi, const Vector &d_wi,
                    const Vector &rnd, Vector &d_rnd,
                    Spectrum &d_value, Vector &d_wo) {
    Spectrum value = Spectrum::Zero();
    Vector wo = Vector::Zero();
    RoughDielectricBSDF rd_copy = rd;
    Intersection d_its;
    __enzyme_autodiff((void *)RD_BSDF_test,
                        enzyme_const, &rd,
                        enzyme_dup, &its, &d_its,
                        enzyme_const, &geo_n,
                        enzyme_dup, &wi, &d_wi,
                        enzyme_dup, &rnd, &d_rnd,
                        enzyme_dup, &value, &d_value,
                        enzyme_dup, &wo, &d_wo  
                        );
}

ArrayXd MetropolisIndirectEdgeIntegrator::renderD(
    SceneAD &sceneAD, RenderOptions &options, const ArrayXd &__d_image) const
{
    // PSDR_INFO("BSDF test: ");
    // RoughDielectricBSDF rd = RoughDielectricBSDF(0.1, 1.5, 1.0);
    // // RoughConductorBSDF rc = RoughConductorBSDF(0.1, Spectrum(0.155475, 0.116753, 0.138334), Spectrum(4.83181, 3.12296, 2.14866));
    // // DiffuseBSDF di = DiffuseBSDF(Spectrum(0.5, 0.5, 0.5));
    // Intersection its;
    // Float theta0 = M_PI / 4;
    // Vector wi = Vector(0.0, sin(theta0), cos(theta0));
    // Vector geo_n = Vector(0.0, 0.0, 1.0);
    // Vector rnd = Vector(0.3, 0.6, 0.9);
    // Spectrum value = Spectrum::Zero();
    // Vector wo = Vector::Zero();
    // RD_BSDF_test_echo(rd, geo_n, wi, rnd, value, wo);
    // Spectrum d_value = Spectrum(1.0, 1.0, 1.0);
    // // Spectrum d_value = Spectrum::Zero();
    // // Vector d_wo = Vector(1.0, 1.0, 1.0);
    // Vector d_wo = Vector::Zero();
    // Vector d_rnd = Vector::Zero();
    // Vector fd_d_rnd_val = Vector::Zero();
    // Vector fd_d_rnd_wo = Vector::Zero();
    // Vector fd_d_wi_val = Vector::Zero();
    // Vector fd_d_wi_wo = Vector::Zero();
    // Vector d_wi = Vector::Zero();
    // for (int i = 0; i < 3; i++) {
    //     Vector rnd_i = rnd;
    //     Vector wi_i = wi;
    //     rnd_i[i] += 0.0001;
    //     wi_i[i] += 0.0001;
    //     Spectrum value_i = Spectrum::Zero();
    //     Vector wo_i = Vector::Zero();
    //     RD_BSDF_test_echo(rd, geo_n, wi, rnd_i, value_i, wo_i);
    //     fd_d_rnd_val[i] = (value_i - value).sum() / 0.0001;
    //     fd_d_rnd_wo[i] = (wo_i - wo).sum() / 0.0001;
    //     RD_BSDF_test_echo(rd, geo_n, wi_i, rnd, value_i, wo_i);
    //     fd_d_wi_val[i] = (value_i - value).sum() / 0.0001;
    //     fd_d_wi_wo[i] = (wo_i - wo).sum() / 0.0001;
    // }
    // d_RD_BSDF_test(rd, its, geo_n, wi, d_wi, rnd, d_rnd, d_value, d_wo);
    // ArrayXd output_array(9);
    // PSDR_INFO("d_rnd: {}, {}, {}", d_rnd[0], d_rnd[1], d_rnd[2]);
    // PSDR_INFO("fd_d_rnd_val: {}, {}, {}", fd_d_rnd_val[0], fd_d_rnd_val[1], fd_d_rnd_val[2]);
    // PSDR_INFO("fd_d_rnd_wo: {}, {}, {}", fd_d_rnd_wo[0], fd_d_rnd_wo[1], fd_d_rnd_wo[2]);
    // PSDR_INFO("d_wi: {}, {}, {}", d_wi[0], d_wi[1], d_wi[2]);
    // PSDR_INFO("fd_d_wi_val: {}, {}, {}", fd_d_wi_val[0], fd_d_wi_val[1], fd_d_wi_val[2]);
    // PSDR_INFO("fd_d_wi_wo: {}, {}, {}", fd_d_wi_wo[0], fd_d_wi_wo[1], fd_d_wi_wo[2]);
    // PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);

    Scene tmp_der = sceneAD.der;
    GradientManager<Scene> gm_doll(tmp_der, omp_get_num_procs());

    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_secondary_edge_indirect;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        return debugInfo.getArray();

    // return debugInfo.getArray();
    PSDR_INFO("MLTEdgeIntegrator::renderD with spp = {}",
              mala_config.num_chains * mala_config.num_samples / camera.getNumPixels());
    PSDR_INFO("MALA Config: {}, {}, {}", mala_config.scale[0], mala_config.scale[1], mala_config.scale[2]);
    Timer _("Indirect boundary");

/* -----------------------------Phase 0: Compute connectivity-----------------------------*/
    Timer *phase_0_timer = new Timer("Phase 0: Compute connectivity");
    DiscreteDistribution edge_dist_unnorm;

    for (int i = 0; i < edge_dist.size(); i++){
        edge_dist_unnorm.append(1.0);
    }

    std::vector<std::vector<int>> edge_indices_inv;
    edge_indices_inv.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        edge_indices_inv[i].resize(scene.shape_list[i]->edges.size());
        for (int j = 0; j < scene.shape_list[i]->edges.size(); j++){
            edge_indices_inv[i][j] = -1;
        }
    }
    for (int i = 0; i < edge_indices.size(); i++){
        edge_indices_inv[edge_indices[i].x()][edge_indices[i].y()] = i;
    }
    PSDR_INFO("edge_indices.size(): {}", edge_indices.size());

    std::vector<std::vector<Vector3i>> face_to_edge_indices;
    face_to_edge_indices.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        face_to_edge_indices[i].resize(scene.shape_list[i]->indices.size());
        for (int j = 0; j < scene.shape_list[i]->indices.size(); j++){
            face_to_edge_indices[i][j] = Vector3i(-1, -1, -1);
        }
    }
    // for (int i = 0; i < edge_indices.size(); i++){
    //     Vector2i edge = edge_indices[i];
    //     Shape* shape = scene.shape_list[edge.x()];
    //     Edge e = shape->edges[edge.y()];
    //     if (e.f0 >= 0){
    //         Vector3i &f0_to_edge = face_to_edge_indices[edge.x()][e.f0];
    //         if (f0_to_edge.x() < 0){
    //             f0_to_edge.x() = i;
    //         } else if (f0_to_edge.y() < 0){
    //             f0_to_edge.y() = i;
    //         } else if (f0_to_edge.z() < 0){
    //             f0_to_edge.z() = i;
    //         }
    //     }
    //     if (e.f1 >= 0){
    //         Vector3i &f1_to_edge = face_to_edge_indices[edge.x()][e.f1];
    //         if (f1_to_edge.x() < 0){
    //             f1_to_edge.x() = i;
    //         } else if (f1_to_edge.y() < 0){
    //             f1_to_edge.y() = i;
    //         } else if (f1_to_edge.z() < 0){
    //             f1_to_edge.z() = i;
    //         }
    //     }
    // }
    for (int i = 0; i < scene.shape_list.size(); i++) {
        Shape* shape = scene.shape_list[i];
        for (int j = 0; j < shape->edges.size(); j++){
            Edge e = shape->edges[j];
            if (e.f0 >= 0){
                Vector3i &f0_to_edge = face_to_edge_indices[i][e.f0];
                if (f0_to_edge.x() < 0){
                    f0_to_edge.x() = j;
                } else if (f0_to_edge.y() < 0){
                    f0_to_edge.y() = j;
                } else if (f0_to_edge.z() < 0){
                    f0_to_edge.z() = j;
                }
            }
            if (e.f1 >= 0){
                Vector3i &f1_to_edge = face_to_edge_indices[i][e.f1];
                if (f1_to_edge.x() < 0){
                    f1_to_edge.x() = j;
                } else if (f1_to_edge.y() < 0){
                    f1_to_edge.y() = j;
                } else if (f1_to_edge.z() < 0){
                    f1_to_edge.z() = j;
                }
            }
        }
    }
    delete phase_0_timer;

/* --------------------------------------- Phase 1 ---------------------------------------*/
    Timer *phase_1_timer = new Timer("Phase 1: Sample indirect boundary");
    int blockProcessed = 0;
    int phase1_samples = 1<<(mala_config.phase_one_samples);
    Spectrum phase1_sum = Spectrum::Zero();

    int valid_sample = 0;
    int nblocks_phase1 = std::ceil(sqrt(phase1_samples));
    int nsample_per_block = phase1_samples / nblocks_phase1;
    std::vector<SeedPath> seed_paths;
    seed_paths.clear();

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks_phase1; i++)
    {
        RndSampler sampler(options.seed, i);
        Spectrum thread_sum = Spectrum::Zero();
        std::vector<SeedPath> seed_paths_thread;
        seed_paths_thread.clear();
        for (int j = 0; j < nsample_per_block; j++)
        {
            int idx = i * nsample_per_block + j;
            if (idx >= phase1_samples)
                break;
            Float pdf = 1.0;

            // MALA::MALAVector rnd(n_bounces * 3);
            Vector d_rnd = sampler.next3D();
            if (grid_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = grid_distrb.sample(d_rnd, pdf);
            }
            if (aq_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = aq_distrb.sample(d_rnd, pdf);
            }

            std::vector<SeedPath> seed_paths_sample;
            seed_paths_sample.clear();
            Spectrum contrib = sampleIndirectBoundary(scene, d_rnd, sampler, options.max_bounces, edge_dist, edge_indices, seed_paths_sample);
            if (std::isnan(contrib.sum() / pdf))
                continue;
            for (int k = 0; k < seed_paths_sample.size(); k++){
                seed_paths_sample[k].weight /= pdf;
                seed_paths_thread.push_back(seed_paths_sample[k]);
            }

            thread_sum += contrib / pdf;
        }
#pragma omp critical
        {
            phase1_sum[0] += thread_sum[0] / (Float)phase1_samples;
            phase1_sum[1] += thread_sum[1] / (Float)phase1_samples;
            phase1_sum[2] += thread_sum[2] / (Float)phase1_samples;
            valid_sample += seed_paths_thread.size();
            for (int k = 0; k < seed_paths_thread.size(); k++){
                seed_paths.push_back(seed_paths_thread[k]);
            }
        }
    }
    RndSampler sampler_wrs(options.seed, nblocks_phase1);
    PSDR_INFO("valid_sample: {}, num_chains: {}", seed_paths.size(), mala_config.num_chains);
    // if (valid_sample < mala_config.num_chains){
    //     assert(false);
    // }
    // return debugInfo.getArray();
    std::vector<std::vector<SeedPath>> sample_per_thread_vec(nworker);
    int iter = 1;
    while (seed_paths.size() < mala_config.num_chains) {
        // int new_sample_required = float(used_samples) * float(mala_config.num_chains - seed_paths.size()) / float(seed_paths.size());
        int new_sample_required = phase1_samples;
        int nblocks_phase1_new = std::ceil(sqrt(new_sample_required));
        int nsample_per_block_new = new_sample_required / nblocks_phase1_new;
        ++iter;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (int i = 0; i < nblocks_phase1_new; i++)
        {
            
            RndSampler sampler(options.seed + iter, i);
            Spectrum thread_sum = Spectrum::Zero();
            std::vector<SeedPath> seed_paths_thread;
            seed_paths_thread.clear();
            for (int j = 0; j < nsample_per_block_new; j++)
            {
                int idx = i * nsample_per_block + j;
                if (idx >= phase1_samples)
                    break;
                Float pdf = 1.0;

                // MALA::MALAVector rnd(n_bounces * 3);
                Vector d_rnd = sampler.next3D();
                if (grid_distrb.distrb.getSum() > Epsilon)
                {
                    d_rnd = grid_distrb.sample(d_rnd, pdf);
                }
                if (aq_distrb.distrb.getSum() > Epsilon)
                {
                    d_rnd = aq_distrb.sample(d_rnd, pdf);
                }

                std::vector<SeedPath> seed_paths_sample;
                seed_paths_sample.clear();
                Spectrum contrib = sampleIndirectBoundary(scene, d_rnd, sampler, options.max_bounces, edge_dist, edge_indices, seed_paths_sample);
                if (std::isnan(contrib.sum() / pdf))
                    continue;
                for (int k = 0; k < seed_paths_sample.size(); k++){
                    seed_paths_sample[k].weight /= pdf;
                    seed_paths_thread.push_back(seed_paths_sample[k]);
                }

                thread_sum += contrib / pdf;
            }
    #pragma omp critical
            {
                phase1_sum[0] += thread_sum[0] / (Float)phase1_samples;
                phase1_sum[1] += thread_sum[1] / (Float)phase1_samples;
                phase1_sum[2] += thread_sum[2] / (Float)phase1_samples;
                for (int k = 0; k < seed_paths_thread.size(); k++){
                    seed_paths.push_back(seed_paths_thread[k]);
                }
            }
        }
        PSDR_INFO("valid_sample: {}, num_chains: {}", seed_paths.size(), mala_config.num_chains);
    }

    auto comp = [](const std::pair<Float, int>& a, const std::pair<Float, int>& b) {
        return a.first > b.first; // Return true if a is greater than b
    };

    std::vector<MALA::MALAVector> init_samples;
    std::vector<std::pair<int, int>> init_sample_path_lengths;
    init_samples.resize(mala_config.num_chains);
    init_sample_path_lengths.resize(mala_config.num_chains);

    std::priority_queue<std::pair<Float, int>, std::vector<std::pair<Float, int>>, decltype(comp)> pq(comp);
    for (int i = 0; i < seed_paths.size(); i++){
        if (seed_paths[i].weight.sum() < 1e-50) {
            continue;
        }
        Float r = pow(sampler_wrs.next1D(), 1.0 / (seed_paths[i].weight.sum()));
        if (pq.size() < init_samples.size()){
            pq.push(std::make_pair(r, i));
        }
        else {
            if (pq.top().first < r){
                pq.pop();
                pq.push(std::make_pair(r, i));
            }
        }
    }
    while(!pq.empty()){
        // PSDR_INFO("pq.top().first: {}", pq.top().first);
        int i = pq.top().second;
        init_samples[pq.size() - 1] = seed_paths[i].pss;
        init_sample_path_lengths[pq.size() - 1] = {seed_paths[i].max_b, seed_paths[i].cam_b};
        pq.pop();
    }
    assert(!std::isnan(phase1_sum[1]));

    Spectrum phase1_mean = phase1_sum;
    PSDR_INFO("phase 1 sum: {}, {}, {}", phase1_sum[0], phase1_sum[1], phase1_sum[2]);
    PSDR_INFO("phase 1: {}, {}, {}", phase1_mean[0], phase1_mean[1], phase1_mean[2]);
    // return debugInfo.getArray();

    // // finite difference test:
    // MALA::MALAVector candidate0 = init_samples[0];
    // MALA::MALAVector d_candidate_fd = candidate0;
    // std::pair<int, int> path_length = init_sample_path_lengths[0];
    // algorithm1_MALA_indirect::LightPathPSS pss(candidate0.size());
    // PSDR_INFO("candidate0.size: {}", candidate0.size());
    // PSDR_INFO("candidate0: {}, {}, {}", candidate0[0], candidate0[1], candidate0[2]);
    // PSDR_INFO("path_length: {}, {}", path_length.first, path_length.second);

    // Spectrum f0 = algorithm1_MALA_indirect::eval(scene, candidate0, path_length.first, path_length.second, edge_dist, edge_indices, &pss, true);
    // algorithm1_MALA_indirect::LightPathPSSAD pssAD(pss);
    // PSDR_INFO("f0: {}, {}, {}", f0[0], f0[1], f0[2]);
    // PSDR_INFO("f0.sum: {}", f0[0] + f0[1] + f0[2]);
    // algorithm1_MALA_indirect::d_eval(scene, gm_doll.get(0), pssAD);
    // MALA::MALAVector d_candidate = pssAD.der.getPSSState(path_length.first, path_length.second);
    // for (int i = 0; i < candidate0.size(); i++) {
    //     MALA::MALAVector candidate0_ = candidate0;
    //     candidate0_[i] += 1e-4;
    //     Spectrum f = algorithm1_MALA_indirect::eval(scene, candidate0_, path_length.first, path_length.second, edge_dist, edge_indices);
    //     d_candidate_fd[i] = (f[0] + f[1] + f[2] - f0[0] - f0[1] - f0[2]) / 1e-4;
    //     PSDR_INFO("d_candidate[{}]: {}", i, d_candidate[i]);
    //     PSDR_INFO("d_candidate_fd[{}]: {}", i, d_candidate_fd[i]);
    // }
    // Vector d_edge_point = pssAD.der.edge_point;
    // PSDR_INFO("d_edge_point: {}, {}, {}", d_edge_point[0], d_edge_point[1], d_edge_point[2]);
    // Vector d_edge_dir = pssAD.der.edge_dir;
    // PSDR_INFO("d_edge_dir: {}, {}, {}", d_edge_dir[0], d_edge_dir[1], d_edge_dir[2]);
    // Vector d_rnd0 = MALA::pss_get(d_candidate, 0);
    // Vector rnd0 = MALA::pss_get(candidate0, 0);
    // Shape *shape = scene.shape_list[pssAD.val.ePSRec.shape_id];
    // Edge e = shape->edges[pssAD.val.ePSRec.edge_id];
    // return debugInfo.getArray();


    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / mala_config.num_chains / mala_config.num_samples, camera.getNumPixels());

    for (int i = 0; i < _d_image_spec_list.size(); i++){
        _d_image_spec_list[i] *= phase1_mean;
    }
    delete phase_1_timer;

/* --------------------------------------- Phase 2 ---------------------------------------*/
    Timer *phase_2_timer = new Timer("Phase 2: MALA sampling");
    blockProcessed = 0;
    int cache_size = 8000;
    MALA::Mutation* mutations[nworker];
    // MutationHybrid* mutations_ref[nworker];
    for (int i = 0; i < nworker; i++){
        switch (mala_config.mode) {
            case 0:
                mutations[i] = new MALA::MutationDiminishing();
                break;
            case 1:
                mutations[i] = new MALA::MutationCacheBased();
                break;
            case 2:
                mutations[i] = new MALA::MutationHybrid();
                break;
            default:
                mutations[i] = new MALA::MutationDiminishing();
                break;
        }
        // mutations_ref[i] = new MutationHybrid();
    }
    MALA::KNNCache* cache[mala_config.num_chains];
    // MALA::GridCache* grid_cache[nworker];
    // int max_cache = 2000;
    for (int i = 0; i < mala_config.num_chains; i++){
        cache[i] = new MALA::KNNCache((options.max_bounces + 1) * 3, cache_size);
        // grid_cache[i] = new MALA::GridCache(cache_size);
    }
    int accepted = 0;
    int burn_in = mala_config.burn_in;
    Float avg_step_length = 0.0;
    int require_dim1 = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < mala_config.num_chains; i++)
    {
        const int tid = omp_get_thread_num();
        int thread_accepted = 0;
        int thread_require_dim1 = 0;
        RndSampler sampler(options.seed, i);
        MALA::Mutation *mutation = mutations[tid];
        // MutationHybrid *mutation_ref = mutations_ref[tid];
        // PSDR_INFO("chain start: tid {}", tid);
        mutation->setZero();
        MALA::KNNCache *cache_ptr = cache[i];
        // MALA::GridCache *grid_cache_ptr = grid_cache[tid];

// #pragma omp critical
//         {
//             if (!global_cache->write) {
//                 cache_ptr->copy(*global_cache);
//             }
//         }
        Float step_length = mala_config.step_length;
        Float target_acceptance_rate = 0.57;
        Float A = 0.0;
        MALA::MALAVector current_u = init_samples[i];
        std::pair<int, int> current_path_length = init_sample_path_lengths[i];
        int n_bounces = current_path_length.first;
        assert(n_bounces == current_u.size() / 3);
        int cam_bounces = current_path_length.second;
        // assert(n_bounces == 2);
        for (int j = 0; j < mala_config.num_samples + burn_in; j++)
        {
            // current_u = sampler.next3D();
            // Float weight;
            // eval_DirectBoundary(current_u, scene, 1, edge_dist, edge_indices, true, weight);
            // d_sampleDirectBoundary(scene, gm.get(tid), sampler,
            //                         edge_dist, edge_indices, 1.0,
            //                         _d_image_spec_list, current_u);
            // PSDR_INFO("chain: {}, sample: {}", i, j);
            bool global = false;
            Float p_global = sampler.next1D();
            MALA::MALAVector proposed_u(n_bounces * 3);
            proposed_u.setZero();
            // PSDR_INFO("here1");
            int n = n_bounces, i = cam_bounces;
            int require_mut_dim1 = -1;
            if (p_global < mala_config.p_global) {
                A = mutateLargeStep(scene, &sampler, options.max_bounces, n_bounces, cam_bounces, 
                                edge_dist, edge_indices, grid_distrb, current_u, proposed_u, n, i);
                global = true;
                // sample a point on the boundary
            } else {
                A = mutateSmallStep(scene, gm_doll.get(tid), &sampler, n_bounces, cam_bounces,
                            *mutation, *cache_ptr,
                            edge_dist, edge_indices, edge_indices_inv, face_to_edge_indices,
                            step_length, mala_config,
                            current_u, proposed_u, &require_mut_dim1);
            }
            if (require_mut_dim1 >= 0){
                thread_require_dim1++;
            }
            // PSDR_INFO("here2");
            Float a = sampler.next1D();
            if (a < A)
            {
                thread_accepted++;
                if (current_u.size() != proposed_u.size()){
                    current_u.resize(proposed_u.size());
                }
                if (global){
                    n_bounces = n;
                    cam_bounces = i;
                    mutation->setZero();
                }
                current_u = proposed_u;
            }
            // amcmc
            // Float acceptance_rate = Float(thread_accepted + 1) / (j + 2);
            // step_length = step_length * std::pow((acceptance_rate) / target_acceptance_rate, 1.0 / (Float(j) / 1000.0 + 1.0));

            // PSDR_INFO("here3");
            if (j > burn_in){
                if (j % mala_config.thinning == 0){
                    d_sampleIndirectBoundary(scene, gm.get(tid), sampler, n_bounces, cam_bounces,
                                        edge_dist, edge_indices,
                                        _d_image_spec_list, current_u);
                }
            }
            // PSDR_INFO("here4");
        }
        // PSDR_INFO("step_length: {}", step_length);
        
        // PSDR_INFO("local cache size: {}", cache_ptr->size());
        // PSDR_INFO("chain end");
//         if (global_cache->write){
// #pragma omp critical
//             {
//                 global_cache->merge(*cache_ptr);
//                 if (global_cache->size() >= cache_size){
//                     global_cache->write = false;
//                 }
//             }
//         }
//         delete cache_ptr;

        if (verbose){
#pragma omp critical
            {
                progressIndicator(static_cast<Float>(++blockProcessed) / mala_config.num_chains);
                accepted += thread_accepted;
                require_dim1 += thread_require_dim1;
            }
            // PSDR_INFO("acceptance rate: {}", Float(thread_accepted) / (10000));
        }
    }
    if (verbose){
        std::cout << std::endl;
        PSDR_INFO("acceptance rate: {}", Float(accepted) / (mala_config.num_chains * mala_config.num_samples));
        PSDR_INFO("require_mut_dim1 rate: {}%", Float(require_dim1 * 100) / (mala_config.num_chains * mala_config.num_samples));
    }
    for (int i = 0; i < nworker; i++){
        delete mutations[i];
    }
    for (int i = 0; i < mala_config.num_chains; i++){
        delete cache[i];
    }

    // merge gradient to d_scene
    gm_doll.setZero();
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    d_precompute_normal(scene, d_scene);
#endif
    // delete global_cache;
    delete phase_2_timer;
    return flattened(debugInfo.getArray());
}

// roughness-capped guiding
/**
 * Clone `bsdf`, raising the microfacet roughness of rough BSDFs to at least `cap`.
 *
 * Only "RoughConductorBSDF" and "RoughDielectricBSDF" (identified by className())
 * are modified: their `m_distr` is rebuilt with alpha = max(alpha, cap). Any other
 * BSDF type is returned as a plain clone.
 *
 * @param bsdf  BSDF to copy; it is not modified.
 * @param cap   Lower bound applied to the microfacet alpha (must be > 0).
 * @return      The object returned by bsdf->clone(), possibly with its roughness
 *              raised. Presumably heap-allocated by clone(), so the caller is
 *              responsible for deleting it (TODO confirm against BSDF::clone).
 */
BSDF* get_new_capped_bsdf(const BSDF* bsdf, Float cap = 0.05) {
    assert(cap > 0.0);

    char type_name[100];
    type_name[0] = 0;
    bsdf->className(type_name);

    const bool is_rough_conductor = strcmp(type_name, "RoughConductorBSDF") == 0;
    const bool is_rough_dielectric = strcmp(type_name, "RoughDielectricBSDF") == 0;

    if (is_rough_conductor) {
        RoughConductorBSDF* floored = dynamic_cast<RoughConductorBSDF*>(bsdf->clone());
        const auto alpha = floored->m_distr.m_alpha;
        // Smooth lobes are widened to `cap`; already-rough lobes are kept as-is.
        floored->m_distr = MicrofacetDistribution(alpha < cap ? cap : alpha);
        return floored;
    }

    if (is_rough_dielectric) {
        RoughDielectricBSDF* floored = dynamic_cast<RoughDielectricBSDF*>(bsdf->clone());
        const auto alpha = floored->m_distr.m_alpha;
        floored->m_distr = MicrofacetDistribution(alpha < cap ? cap : alpha);
        return floored;
    }

    // Any other BSDF is copied unchanged.
    return bsdf->clone();
}

/**
 * Trace `num_paths` random walks through `scene` and record a RadImpNode
 * (position, throughput, depth) at every path vertex. Surface scattering is
 * sampled with the roughness-floored BSDF copies produced by get_new_capped_bsdf().
 *
 * @param scene        Scene to trace.
 * @param num_paths    Number of walks to trace (distributed over OpenMP threads).
 * @param max_bounces  A walk keeps extending while `depth <= max_bounces`.
 * @param nodes        Output; nodes recorded by all threads are appended to it.
 * @param importance   false: walks start with a primary camera ray and BSDFs are
 *                     sampled in ERadiance mode.
 *                     true:  walks start at a sampled emitter position (recorded
 *                     at depth 0) and BSDFs are sampled in
 *                     EImportanceWithCorrection mode.
 */
void buildPhotonMap_capped(const Scene &scene, int num_paths, int max_bounces, std::vector<RadImpNode> &nodes, bool importance)
{
    const int nworker = omp_get_num_procs();
    // One sampler per thread, selected by omp_get_thread_num() inside the loop.
    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; ++i)
        samplers.push_back(RndSampler(17, i));

    // Per-thread node buffers: no synchronization is needed inside the parallel loop.
    std::vector<std::vector<RadImpNode>> nodes_per_thread(nworker);
    for (int i = 0; i < nworker; i++)
    {
        nodes_per_thread[i].reserve(num_paths / nworker * max_bounces);
    }
    const Camera &camera = scene.camera;
    const CropRectangle &rect = camera.rect;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (size_t omp_i = 0; omp_i < (size_t)num_paths; omp_i++)
    {
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];
        Ray ray;
        Spectrum throughput;
        Intersection its;
        const Medium *ptr_med = nullptr;
        int depth = 0;
        bool onSurface = false;
        if (!importance)
        {
            // Trace ray from camera
            Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width : sampler.next1D() * camera.width;
            Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height : sampler.next1D() * camera.height;
            ray = camera.samplePrimaryRay(x, y);
            throughput = Spectrum(1.0);
            ptr_med = camera.getMedID() == -1 ? nullptr : scene.medium_list[camera.getMedID()];
        }
        else
        {
            // Trace ray from emitter
            throughput = scene.sampleEmitterPosition(sampler.next2D(), its);
            ptr_med = its.ptr_med_ext;
            nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
            ray.org = its.p;
            throughput *= its.ptr_emitter->sampleDirection(sampler.next2D(), ray.dir);
            ray.dir = its.geoFrame.toWorld(ray.dir);
            depth++;
            onSurface = true;
        }
        if (scene.rayIntersect(ray, onSurface, its))
        {
            while (depth <= max_bounces)
            {
                bool inside_med = ptr_med != nullptr &&
                                    ptr_med->sampleDistance(Ray(ray), its.t, sampler.next2D(), &sampler, ray.org, throughput);
                if (inside_med)
                {
                    if (throughput.isZero())
                        break;
                    const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
                    nodes_per_thread[tid].push_back(RadImpNode{ray.org, throughput, depth});
                    Float phase_val = ptr_phase->sample(-ray.dir, sampler.next2D(), ray.dir);
                    if (phase_val == 0.0)
                        break;
                    throughput *= phase_val;
                    scene.rayIntersect(ray, false, its);
                }
                else
                {
                    nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                    Float bsdf_pdf, bsdf_eta;
                    Vector wo_local, wo;
                    // get_new_capped_bsdf() returns a fresh clone on every call and the
                    // caller owns it. It is only needed for this one sample() call, so
                    // free it immediately instead of accumulating one allocation per bounce.
                    BSDF *capped_bsdf = get_new_capped_bsdf(its.ptr_bsdf);
                    Spectrum bsdf_weight = capped_bsdf->sample(its, sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                                                               importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                    delete capped_bsdf;
                    if (bsdf_weight.isZero())
                        break;
                    throughput = throughput * bsdf_weight;

                    wo = its.toWorld(wo_local);
                    Vector wi = -ray.dir;
                    Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                    // Reject directions where the geometric and shading hemispheres disagree.
                    if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                        break;

                    if (its.isMediumTransition())
                        ptr_med = its.getTargetMedium(woDotGeoN);

                    ray = Ray(its.p, wo);
                    if (!scene.rayIntersect(ray, true, its))
                        break;
                }
                depth++;
            }
        }
    }

    // Concatenate the per-thread buffers into the output.
    size_t sz_node = 0;
    for (int i = 0; i < nworker; i++)
        sz_node += nodes_per_thread[i].size();
    nodes.reserve(sz_node);
    for (int i = 0; i < nworker; i++)
        nodes.insert(nodes.end(), nodes_per_thread[i].begin(), nodes_per_thread[i].end());
}

/**
 * k-nearest-neighbour query of NUM_NEAREST_NEIGHBORS photon nodes around a point.
 *
 * @param indices          KD-tree built over the photon node positions.
 * @param query_point      Query position; read as 3 Floats (x, y, z) by the callers in this file.
 * @param matched_indices  Output; receives the matched node indices. Must have room
 *                         for NUM_NEAREST_NEIGHBORS entries.
 * @param matched_dist_sqr Output; squared distance reported for the last returned
 *                         neighbour (presumably the farthest, if knnSearch returns
 *                         results sorted by distance — TODO confirm), or +infinity
 *                         when nothing was matched.
 * @return Number of neighbours found.
 */
int queryPhotonMap(const KDtree<Float> &indices, const Float *query_point, size_t *matched_indices, Float &matched_dist_sqr)
{
    Float dist_sqr[NUM_NEAREST_NEIGHBORS];
    const int num_matched = indices.knnSearch(query_point, NUM_NEAREST_NEIGHBORS, matched_indices, dist_sqr);
    assert(num_matched == NUM_NEAREST_NEIGHBORS);
    if (num_matched <= 0)
    {
        // With NDEBUG the assert is compiled out, and an empty result would read
        // dist_sqr[-1]. Callers divide by this radius, so an infinite value makes
        // their density estimate vanish for any finite numerator.
        matched_dist_sqr = std::numeric_limits<Float>::infinity();
        return 0;
    }
    matched_dist_sqr = dist_sqr[num_matched - 1];
    return num_matched;
}


/**
 * Scalar guiding value of the direct boundary term at the primary sample `rnd_val`.
 *
 * Samples an edge ray with sampleEdgeRay() and intersects it in both directions:
 * forward hit `itsS` (must be an emitter), backward hit `itsD`. The value is the
 * emitter radiance times a photon-density estimate around itsS.p (only nodes with
 * depth <= max_bounces are counted), scaled by the change-of-variables Jacobian,
 * the geometric term between xD and xS, and 1 / eRec.pdf. Zero is returned for
 * invalid samples, back-facing/grazing emitter hits, near-zero Jacobians,
 * points not visible from the camera, and zero camera attenuation.
 *
 * @param rnd_val      Primary-sample-space point used to sample the edge ray.
 * @param sampler      Unused here; kept for interface compatibility.
 * @param rad_nodes    Photon nodes indexed by `rad_indices`.
 * @return |estimate| (max component of the absolute value).
 */
Float eval_photon_DirectBoundary(const Vector3 &rnd_val,
                                    const Scene &scene,
                                    RndSampler &sampler,
                                    const DiscreteDistribution &edge_dist,
                                    const std::vector<Vector2i> &edge_indices, int max_bounces,
                                    const std::vector<RadImpNode> &rad_nodes,
                                    const KDtree<Float> &rad_indices)
{
    BoundarySamplingRecord eRec;
    sampleEdgeRay(scene, rnd_val, edge_dist, edge_indices, eRec);
    if (eRec.shape_id < 0)
    {
        return 0.0f;
    }
    // The estimate divides by eRec.pdf; a zero, negative or NaN pdf would produce
    // inf/NaN guiding values.
    if (!(eRec.pdf > 0))
    {
        return 0.0f;
    }
    const Shape *shape = scene.shape_list[eRec.shape_id];
    const Edge &edge = shape->edges[eRec.edge_id];

    Ray edgeRay(eRec.ref, eRec.dir);
    Intersection itsS, itsD;
    if (!scene.rayIntersect(edgeRay, true, itsS) ||
        !scene.rayIntersect(edgeRay.flipped(), true, itsD))
        return 0.0;

    if (itsS.ptr_emitter == nullptr)
        return 0.0;

    Spectrum value = itsS.ptr_emitter->eval(itsS, -edgeRay.dir);

    // Reject back-facing emitter hits (and grazing ones for non-transmissive BSDFs).
    Float gnDotD = itsS.geoFrame.n.dot(-edgeRay.dir);
    Float snDotD = itsS.shFrame.n.dot(-edgeRay.dir);
    bool success = (itsS.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                    (!itsS.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
    if (!success)
        return 0.0f;
    // populate the data in BoundarySamplingRecord eRec
    eRec.dir = -edgeRay.dir;
    eRec.shape_id_S = itsS.indices[0];
    eRec.tri_id_S = itsS.indices[1];
    eRec.shape_id_D = itsD.indices[0];
    eRec.tri_id_D = itsD.indices[1];

    /* Jacobian determinant that accounts for the change of variable */
    const Vector v0 = shape->getVertex(edge.v0);
    const Vector v1 = shape->getVertex(edge.v1);
    const Vector xB = v0 + (v1 - v0) * eRec.t;
    const Vector &xS = itsS.p;
    const Vector &xD = itsD.p;
    Float J = dlD_dlB(xS,
                        xB, (v0 - v1).normalized(),
                        xD, itsD.geoFrame.n) *
                dA_dw(xB, xS, itsS.geoFrame.n);
    Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n);
    if (std::abs(baseValue) < Epsilon)
        return 0.0f;

    Matrix2x4 pixel_uvs;
    Array4 attenuations(0.0);
    Vector dir;
    if (!scene.isVisible(itsS.p, true, scene.camera.cpos, true))
        return 0.0f;
    scene.camera.sampleDirect(itsS.p, pixel_uvs, attenuations, dir);
    if (attenuations.maxCoeff() < Epsilon)
        return 0.0f;

    // Photon-density estimate around itsS.p: summed node values over the k-NN radius^2.
    size_t matched_indices[NUM_NEAREST_NEIGHBORS];
    Float pt_rad[3] = {itsS.p[0], itsS.p[1], itsS.p[2]};
    Float matched_r2_rad;
    int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
    Spectrum photon_radiances(0.0);
    for (int m = 0; m < num_nearby_rad; m++)
    {
        const RadImpNode &node = rad_nodes[matched_indices[m]];
        if (node.depth <= max_bounces)
            photon_radiances += node.val;
    }

    Spectrum result(0.0);
    result += (value * photon_radiances).maxCoeff() * baseValue / eRec.pdf / matched_r2_rad;
    return result.abs().maxCoeff();
}

/**
 * Scalar guiding value of the indirect boundary term at the primary sample `rnd_val`.
 *
 * Samples an edge ray with sampleEdgeRay() and intersects it forward (`itsS`) and
 * backward (`itsD`). Invalid samples are rejected: non-tangent rays, self-intersections
 * with the edge's own faces, and back-facing/grazing hits. If shape_opt_id != -1, only
 * edges of that shape are kept. The value combines:
 *   - the change-of-variables Jacobian times the geometric term between xD and xS;
 *   - optionally (local_backward) the largest coefficient of d_normal_velocity;
 *   - photon-density estimates from `rad_nodes` around itsD.p and from `imp_nodes`
 *     around itsS.p, summed over depth pairs (m, n) with n >= 1 and m + n < max_bounces.
 * Zero is returned if itsD.p is not visible from the camera, the camera attenuation
 * is ~0, or eRec.pdf < Epsilon.
 *
 * @param sampler         Unused here; kept for interface compatibility.
 * @param shape_opt_id    -1 for all shapes, otherwise the only shape whose edges count.
 * @param local_backward  If true, weight by the max normal-velocity coefficient.
 * @return Non-negative guiding value.
 */
Float eval_photon_InDirectBoundary(const Vector3 &rnd_val,
                                    const Scene &scene,
                                    RndSampler &sampler,
                                    const DiscreteDistribution &edge_dist,
                                    const std::vector<Vector2i> &edge_indices, int max_bounces,
                                    const std::vector<RadImpNode> &rad_nodes, const std::vector<RadImpNode> &imp_nodes,
                                    const KDtree<Float> &rad_indices, const KDtree<Float> &imp_indices, int shape_opt_id, bool local_backward)
{
    BoundarySamplingRecord eRec;
    sampleEdgeRay(scene, rnd_val, edge_dist, edge_indices, eRec);
    if (eRec.shape_id < 0)
    {
        return 0.0f;
    }
    const Shape *shape = scene.shape_list[eRec.shape_id];
    const Edge &edge = shape->edges[eRec.edge_id];

    Ray edgeRay(eRec.ref, eRec.dir);
    Intersection itsS, itsD;
    if (!scene.rayIntersect(edgeRay, true, itsS) ||
        !scene.rayIntersect(edgeRay.flipped(), true, itsD))
        return 0.0;
    // populate the data in BoundarySamplingRecord eRec
    eRec.shape_id_S = itsS.indices[0];
    eRec.tri_id_S = itsS.indices[1];
    eRec.shape_id_D = itsD.indices[0];
    eRec.tri_id_D = itsD.indices[1];

    // make sure the ray is tangent to the surface
    if (edge.f0 >= 0 && edge.f1 >= 0)
    {
        Vector n0 = shape->getGeoNormal(edge.f0),
                n1 = shape->getGeoNormal(edge.f1);
        Float dotn0 = edgeRay.dir.dot(n0),
                dotn1 = edgeRay.dir.dot(n1);
        if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
        {
            PSDR_ASSERT_MSG(false, "Bad edge ray sample: [{}, {}]", dotn0, dotn1);
            return 0.0;
        }
    }

    /* prevent self intersection */
    const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
    if (itsS.indices == ind0 || itsS.indices == ind1 ||
        itsD.indices == ind0 || itsD.indices == ind1)
        return 0.0;

    // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
    const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                gn2d1 = itsD.geoFrame.n.dot(edgeRay.dir), sn2d1 = itsD.shFrame.n.dot(edgeRay.dir);
    bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
            valid2 = (itsD.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!itsD.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
    if (!valid1 || !valid2)
        return 0.0;

    if (eRec.shape_id != shape_opt_id && shape_opt_id != -1)
    {
        return 0.0;
    }

    Float max_normal = 1.0;

    if (local_backward)
    {
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[eRec.shape_id_S];
        const auto &indS = shapeS->getIndices(eRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        const Shape *shapeD = scene.shape_list[eRec.shape_id_D];
        const auto &indD = shapeD->getIndices(eRec.tri_id_D);
        const Vector &xD_0 = shapeD->getVertex(indD[0]);
        const Vector &xD_1 = shapeD->getVertex(indD[1]);
        const Vector &xD_2 = shapeD->getVertex(indD[2]);

        // NOTE(review): the S and D triangles are deliberately swapped when filling
        // segInfo (xS_* <- D vertices, xD_* <- S vertices); confirm this matches the
        // convention d_normal_velocity expects.
        BoundarySegmentInfo segInfo;
        segInfo.xS_0 = xD_0;
        segInfo.xS_1 = xD_1;
        segInfo.xS_2 = xD_2;

        segInfo.xB_0 = xB_0;
        segInfo.xB_1 = xB_1;

        segInfo.xD_0 = xS_0;
        segInfo.xD_1 = xS_1;
        segInfo.xD_2 = xS_2;

        BoundarySegmentInfo d_segInfo;
        d_segInfo.setZero();

        d_normal_velocity(segInfo, d_segInfo, xB_2, eRec.t, eRec.dir, 1.0);

        if (shape_opt_id == -1)
        {
            max_normal = d_segInfo.maxCoeff();
        }
        else
        {
            // Each flag compares a *shape* index with shape_opt_id (previously the third
            // flag compared the D-side triangle index instead of its shape index).
            max_normal = d_segInfo.maxCoeff(eRec.shape_id_S == shape_opt_id, eRec.shape_id == shape_opt_id, eRec.shape_id_D == shape_opt_id);
        }

        if (max_normal < Epsilon)
        {
            return 0.0;
        }
    }

    /* Jacobian determinant that accounts for the change of variable */
    const Vector v0 = shape->getVertex(edge.v0);
    const Vector v1 = shape->getVertex(edge.v1);
    const Vector xB = v0 + (v1 - v0) * eRec.t;
    const Vector &xS = itsS.p;
    const Vector &xD = itsD.p;
    Float J = dlD_dlB(xS,
                        xB, (v0 - v1).normalized(),
                        xD, itsD.geoFrame.n) *
                dA_dw(xB, xS, itsS.geoFrame.n);
    Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n) * max_normal;

    Matrix2x4 pixel_uvs;
    Array4 attenuations(0.0);
    Vector dir;

    if (!scene.isVisible(itsD.p, true, scene.camera.cpos, true))
        return 0.0;
    scene.camera.sampleDirect(itsD.p, pixel_uvs, attenuations, dir);

    if (eRec.pdf < Epsilon)
    {
        return 0.0;
    }
    if (attenuations.maxCoeff() < Epsilon)
    {
        return 0.0;
    }

    size_t matched_indices[NUM_NEAREST_NEIGHBORS];

    Float pt_rad[3] = {itsD.p[0], itsD.p[1], itsD.p[2]};
    Float pt_imp[3] = {itsS.p[0], itsS.p[1], itsS.p[2]};
    Float matched_r2_rad, matched_r2_imp;

    // Per-depth photon sums around itsD.p (rad_nodes) and itsS.p (imp_nodes).
    int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
    std::vector<Spectrum> photon_radiances(max_bounces + 1, Spectrum::Zero());
    for (int m = 0; m < num_nearby_rad; m++)
    {
        const RadImpNode &node = rad_nodes[matched_indices[m]];
        if (node.depth <= max_bounces)
            photon_radiances[node.depth] += node.val;
    }

    int num_nearby_imp = queryPhotonMap(imp_indices, pt_imp, matched_indices, matched_r2_imp);
    std::vector<Spectrum> importance(max_bounces, Spectrum::Zero());
    for (int m = 0; m < num_nearby_imp; m++)
    {
        const RadImpNode &node = imp_nodes[matched_indices[m]];
        if (node.depth < max_bounces)
            importance[node.depth] += node.val;
    }

    // Combine depth pairs with n >= 1 and m + n < max_bounces.
    Spectrum value2 = Spectrum::Zero();
    int impStart = 1;
    for (int m = 0; m <= max_bounces; m++)
    {
        for (int n = impStart; n < max_bounces - m; n++)
            value2 += photon_radiances[m] * importance[n];
    }

    // NOTE(review): this multiplies by eRec.pdf, whereas eval_photon_DirectBoundary
    // divides by it; confirm which is intended for the guiding target.
    // std::abs (not unqualified abs) guarantees the floating-point overload.
    return std::abs(baseValue * eRec.pdf) * std::abs(value2.maxCoeff() / (matched_r2_rad * matched_r2_imp));
}

void MetropolisIndirectEdgeIntegrator::preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
{
    // Builds radiance and importance photon maps, indexes their positions with k-d trees,
    // and fits grid_distrb to the photon-based estimate returned by
    // eval_photon_InDirectBoundary (NaN estimates count as zero).
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::vector<RadImpNode> rad_nodes, imp_nodes;

    std::cout << "[INFO] Indirect Guiding: #camPath = " << opts.num_cam_path << ", #lightPath = " << opts.num_light_path << std::endl;
    // NOTE(review): third argument presumably caps the bounce depth; the boolean presumably
    // selects importance (light) vs radiance (camera) tracing -- confirm in photon_map.h.
    buildPhotonMap_capped(scene, opts.num_cam_path, max_bounces + 1, rad_nodes, false);
    buildPhotonMap_capped(scene, opts.num_light_path, max_bounces, imp_nodes, true);
    std::cout << "[INFO] Indirect Guiding: #rad_nodes = " << rad_nodes.size() << ", #imp_nodes = " << imp_nodes.size() << std::endl;

    // Copies node positions into a nanoflann point-cloud adaptor.
    auto fill_cloud = [](const std::vector<RadImpNode> &nodes, PointCloud<Float> &cloud)
    {
        cloud.pts.resize(nodes.size());
        for (size_t k = 0; k < nodes.size(); ++k)
        {
            cloud.pts[k].x = nodes[k].p[0];
            cloud.pts[k].y = nodes[k].p[1];
            cloud.pts[k].z = nodes[k].p[2];
        }
    };
    PointCloud<Float> rad_cloud, imp_cloud;
    fill_cloud(rad_nodes, rad_cloud);
    fill_cloud(imp_nodes, imp_cloud);

    // The trees reference the clouds above, which stay alive for the whole function.
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    std::cout << "preprocessing IndirectEdgeIntegrator" << std::endl;
    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        Float result = eval_photon_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces,
                                                    rad_nodes, imp_nodes, rad_indices, imp_indices, -1, false);
        return isnan(result) ? 0.0 : result;
    };
    grid_distrb.setup(NEE_function, config);
    std::cout << "finish grid guiding" << std::endl;
}

void MetropolisIndirectEdgeIntegrator::preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
{
    // Builds radiance and importance photon maps, indexes them with k-d trees, and sets up
    // the adaptive distribution aq_distrb from the photon-based integrand estimate,
    // clamped from below to config.eps (NaN estimates also map to config.eps).
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::vector<RadImpNode> rad_nodes, imp_nodes;

    // NOTE(review): argument meaning of buildPhotonMap_capped assumed to mirror preprocess_grid.
    buildPhotonMap_capped(scene, opts.num_cam_path, max_bounces + 1, rad_nodes, false);
    buildPhotonMap_capped(scene, opts.num_light_path, max_bounces, imp_nodes, true);

    // Copies node positions into a nanoflann point-cloud adaptor.
    auto fill_cloud = [](const std::vector<RadImpNode> &nodes, PointCloud<Float> &cloud)
    {
        cloud.pts.resize(nodes.size());
        for (size_t k = 0; k < nodes.size(); ++k)
        {
            cloud.pts[k].x = nodes[k].p[0];
            cloud.pts[k].y = nodes[k].p[1];
            cloud.pts[k].z = nodes[k].p[2];
        }
    };
    PointCloud<Float> rad_cloud, imp_cloud;
    fill_cloud(rad_nodes, rad_cloud);
    fill_cloud(imp_nodes, imp_cloud);

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        Float result = eval_photon_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces,
                                                    rad_nodes, imp_nodes, rad_indices, imp_indices,
                                                    config.shape_opt_id, config.local_backward);
        // Keep the density strictly positive: NaN or tiny values fall back to eps.
        if (isnan(result) || result < config.eps)
            return config.eps;
        return result;
    };

    // Initial 1D CDF over edges the adaptive structure is built on.
    std::vector<Float> cdfx;
    if (!config.edge_draw)
    {
        cdfx = edge_dist.m_cdf;
    }
    else
    {
        std::cout << "AQ using draw edge" << std::endl;
        cdfx = draw_dist.m_cdf;
    }

    aq_distrb.setup(NEE_function, cdfx, config);
}

// plotting helper functions

/**
 * Evaluates algorithm1_MALA_indirect::eval (summed over spectral channels) on a regular
 * size[0] x size[1] x size[2] lattice covering [min, max) and returns the values.
 * Cell (x, y, z) is stored at linear index x + size[0] * (y + size[1] * z).
 */
ArrayXd MetropolisIndirectEdgeIntegrator::solve_Grid(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const {
    const int nx = size[0];
    const int ny = size[1];
    const int nz = size[2];
    const int cells_per_slab = ny * nz;
    ArrayXd vol_array(nx * ny * nz);

    const int nworker = omp_get_num_procs();
    int slabs_done = 0;
    Timer _("Sample Space Density");
    // Each parallel iteration handles one contiguous run of cells_per_slab linear indices.
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int slab = 0; slab < nx; ++slab) {
        const int first = slab * cells_per_slab;
        for (int idx = first; idx < first + cells_per_slab; ++idx) {
            // Decode the linear index: x varies fastest, then y, then z.
            const int xpos = idx % nx;
            const int ypos = (idx / nx) % ny;
            const int zpos = idx / (nx * ny);
            Vector current(min[0] + (max[0] - min[0]) * xpos / nx,
                           min[1] + (max[1] - min[1]) * ypos / ny,
                           min[2] + (max[2] - min[2]) * zpos / nz);
            Spectrum fu_spec = algorithm1_MALA_indirect::eval(scene, current, 1, 0, edge_dist, edge_indices);
            vol_array(idx) = fu_spec.sum();
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++slabs_done) / nx);
    }

    return vol_array;
}

/**
 * Like solve_Grid, but evaluates the photon-map estimate eval_photon_DirectBoundary
 * (with a freshly built radiance photon map) on the lattice covering [min, max).
 * Cell (x, y, z) is stored at linear index x + size[0] * (y + size[1] * z).
 */
ArrayXd MetropolisIndirectEdgeIntegrator::solve_Grid_rough(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const {
    const int nx = size[0];
    const int ny = size[1];
    const int nz = size[2];
    const int cells_per_slab = ny * nz;
    ArrayXd vol_array(nx * ny * nz);

    int max_bounces = 1;
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::vector<RadImpNode> rad_nodes, imp_nodes;
    buildPhotonMap_capped(scene, opts.num_cam_path, max_bounces + 1, rad_nodes, false);
    buildPhotonMap_capped(scene, opts.num_light_path, max_bounces, imp_nodes, true);

    // Copies node positions into a nanoflann point-cloud adaptor.
    auto fill_cloud = [](const std::vector<RadImpNode> &nodes, PointCloud<Float> &cloud) {
        cloud.pts.resize(nodes.size());
        for (size_t k = 0; k < nodes.size(); ++k) {
            cloud.pts[k].x = nodes[k].p[0];
            cloud.pts[k].y = nodes[k].p[1];
            cloud.pts[k].z = nodes[k].p[2];
        }
    };
    PointCloud<Float> rad_cloud, imp_cloud;
    fill_cloud(rad_nodes, rad_cloud);
    fill_cloud(imp_nodes, imp_cloud);

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    const int nworker = omp_get_num_procs();
    int slabs_done = 0;
    Timer _("Sample Space Density");
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int slab = 0; slab < nx; ++slab) {
        const int first = slab * cells_per_slab;
        for (int idx = first; idx < first + cells_per_slab; ++idx) {
            const int xpos = idx % nx;
            const int ypos = (idx / nx) % ny;
            const int zpos = idx / (nx * ny);
            Vector current(min[0] + (max[0] - min[0]) * xpos / nx,
                           min[1] + (max[1] - min[1]) * ypos / ny,
                           min[2] + (max[2] - min[2]) * zpos / nz);
            // Seed per cell so the result does not depend on thread scheduling.
            RndSampler sampler(idx, 0);
            vol_array(idx) = eval_photon_DirectBoundary(current, scene, sampler, edge_dist, edge_indices,
                                                        1, rad_nodes, rad_indices);
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++slabs_done) / nx);
    }

    return vol_array;
}

/**
 * Runs the Markov-chain sampler over the indirect-boundary sample space and returns the
 * visited states that fall strictly inside the box (min, max), rescaled to [0, 1]^3.
 *
 * Phase 1 draws 2^phase_one_samples seed paths (optionally warped by grid_distrb /
 * aq_distrb) and selects num_chains starting states by weighted reservoir sampling.
 * Phase 2 runs num_chains chains of num_samples steps, mixing large steps
 * (mutateLargeStep, probability p_global) and small steps (mutateSmallStep).
 *
 * @param scene    scene to sample
 * @param size     unused; kept for interface compatibility
 * @param min,max  box used to filter and normalise recorded states
 * @return (num_chains * num_samples) x 3 array; row chain * num_samples + step holds the
 *         normalised state after that step, or zeros if it was outside the box (or the
 *         chain had no seed).
 */
ArrayX3d MetropolisIndirectEdgeIntegrator::solve_MALA(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const
{

/* -----------------------------Compute edge_dist_inv----------------------------*/
    // edge_indices_inv[shape][edge] = position of that edge in edge_indices, or -1.
    std::vector<std::vector<int>> edge_indices_inv;
    edge_indices_inv.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        edge_indices_inv[i].resize(scene.shape_list[i]->edges.size());
        for (int j = 0; j < scene.shape_list[i]->edges.size(); j++){
            edge_indices_inv[i][j] = -1;
        }
    }
    for (int i = 0; i < edge_indices.size(); i++){
        edge_indices_inv[edge_indices[i].x()][edge_indices[i].y()] = i;
    }

    // face_to_edge_indices[shape][face] = up to three positions in edge_indices of the
    // sampled edges adjacent to that face; unused slots stay -1.
    std::vector<std::vector<Vector3i>> face_to_edge_indices;
    face_to_edge_indices.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        face_to_edge_indices[i].resize(scene.shape_list[i]->indices.size());
        for (int j = 0; j < scene.shape_list[i]->indices.size(); j++){
            face_to_edge_indices[i][j] = Vector3i(-1, -1, -1);
        }
    }
    for (int i = 0; i < edge_indices.size(); i++){
        Vector2i edge = edge_indices[i];
        Shape* shape = scene.shape_list[edge.x()];
        Edge e = shape->edges[edge.y()];
        if (e.f0 >= 0){
            Vector3i &f0_to_edge = face_to_edge_indices[edge.x()][e.f0];
            if (f0_to_edge.x() < 0){
                f0_to_edge.x() = i;
            } else if (f0_to_edge.y() < 0){
                f0_to_edge.y() = i;
            } else if (f0_to_edge.z() < 0){
                f0_to_edge.z() = i;
            }
        }
        if (e.f1 >= 0){
            Vector3i &f1_to_edge = face_to_edge_indices[edge.x()][e.f1];
            if (f1_to_edge.x() < 0){
                f1_to_edge.x() = i;
            } else if (f1_to_edge.y() < 0){
                f1_to_edge.y() = i;
            } else if (f1_to_edge.z() < 0){
                f1_to_edge.z() = i;
            }
        }
    }
/* --------------------------------------- Phase 1 ---------------------------------------*/

    int blockProcessed = 0;
    int phase1_samples = 1<<(mala_config.phase_one_samples);
    Spectrum phase1_sum = Spectrum::Zero();

    int nblocks_phase1 = std::ceil(sqrt(phase1_samples));
    // Round up: with floor division, nblocks_phase1 * nsample_per_block < phase1_samples
    // whenever phase1_samples is not a perfect square, so samples were silently dropped
    // while the sum is still divided by phase1_samples. The idx check trims the excess.
    int nsample_per_block = (phase1_samples + nblocks_phase1 - 1) / nblocks_phase1;
    std::vector<SeedPath> seed_paths;

    const int nworker = omp_get_num_procs();
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks_phase1; i++)
    {
        RndSampler sampler(0, i);
        Spectrum thread_sum = Spectrum::Zero();
        std::vector<SeedPath> seed_paths_thread;
        for (int j = 0; j < nsample_per_block; j++)
        {
            int idx = i * nsample_per_block + j;
            if (idx >= phase1_samples)
                break;
            Float pdf = 1.0;

            // Optionally warp the uniform sample through the guiding distributions.
            Vector d_rnd = sampler.next3D();
            if (grid_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = grid_distrb.sample(d_rnd, pdf);
            }
            if (aq_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = aq_distrb.sample(d_rnd, pdf);
            }

            std::vector<SeedPath> seed_paths_sample;
            Spectrum contrib = sampleIndirectBoundary(scene, d_rnd, sampler, 1, edge_dist, edge_indices, seed_paths_sample);
            if (std::isnan(contrib.sum() / pdf))
                continue;
            for (int k = 0; k < seed_paths_sample.size(); k++){
                seed_paths_sample[k].weight /= pdf;
                seed_paths_thread.push_back(seed_paths_sample[k]);
            }

            thread_sum += contrib / pdf;
        }
#pragma omp critical
        {
            phase1_sum[0] += thread_sum[0] / (Float)phase1_samples;
            phase1_sum[1] += thread_sum[1] / (Float)phase1_samples;
            phase1_sum[2] += thread_sum[2] / (Float)phase1_samples;
            
            for (int k = 0; k < seed_paths_thread.size(); k++){
                seed_paths.push_back(seed_paths_thread[k]);
            }
        }
    }
    RndSampler sampler_wrs(1, nblocks_phase1);
    PSDR_INFO("valid_sample: {}, num_chains: {}", seed_paths.size(), mala_config.num_chains);
    assert(seed_paths.size() > mala_config.num_chains);

    // Min-heap on the reservoir key: top() is the smallest key currently kept.
    auto comp = [](const std::pair<Float, int>& a, const std::pair<Float, int>& b) {
        return a.first > b.first;
    };

    std::vector<MALA::MALAVector> init_samples;
    std::vector<std::pair<int, int>> init_sample_path_lengths;
    init_samples.resize(mala_config.num_chains);
    init_sample_path_lengths.resize(mala_config.num_chains);

    // Weighted reservoir sampling: key u^(1/w); keep the num_chains largest keys.
    std::priority_queue<std::pair<Float, int>, std::vector<std::pair<Float, int>>, decltype(comp)> pq(comp);
    for (int i = 0; i < seed_paths.size(); i++){
        if (seed_paths[i].weight.sum() < 1e-50) {
            continue;
        }
        Float r = pow(sampler_wrs.next1D(), 1.0 / (seed_paths[i].weight.sum()));
        if (pq.size() < init_samples.size()){
            pq.push(std::make_pair(r, i));
        }
        else {
            if (pq.top().first < r){
                pq.pop();
                pq.push(std::make_pair(r, i));
            }
        }
    }
    // If fewer than num_chains seeds were kept, the trailing init_samples stay empty.
    while(!pq.empty()){
        int i = pq.top().second;
        init_samples[pq.size() - 1] = seed_paths[i].pss;
        init_sample_path_lengths[pq.size() - 1] = {seed_paths[i].max_b, seed_paths[i].cam_b};
        pq.pop();
    }
    assert(!std::isnan(phase1_sum[1]));

    Spectrum phase1_mean = phase1_sum;
    PSDR_INFO("phase 1 sum: {}, {}, {}", phase1_sum[0], phase1_sum[1], phase1_sum[2]);
    PSDR_INFO("phase 1: {}, {}, {}", phase1_mean[0], phase1_mean[1], phase1_mean[2]);

/* --------------------------------------- Phase 2 ---------------------------------------*/
    ArrayX3d point_cloud(mala_config.num_samples * mala_config.num_chains, 3);
    point_cloud.setZero();
    
    Scene tmp_der = scene;
    GradientManager<Scene> gm_doll(tmp_der, omp_get_num_procs());

    // One mutation state per OpenMP thread (indexed by thread id), one KNN cache per chain.
    // std::vector replaces the former non-standard VLAs; caches are freed after phase 2.
    std::vector<MALA::MutationDiminishing> mutations(nworker);
    std::vector<MALA::KNNCache *> caches(mala_config.num_chains, nullptr);
    for (int c = 0; c < mala_config.num_chains; c++){
        caches[c] = new MALA::KNNCache(3, 3000);
    }

    int accepted = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int chain = 0; chain < mala_config.num_chains; chain++)
    {
        const int tid = omp_get_thread_num();
        int thread_accepted = 0;
        RndSampler sampler(13, chain);
        MALA::Mutation *mutation = &mutations[tid];
        
        mutation->setZero();
        MALA::KNNCache *cache_ptr = caches[chain];

        Float step_length = mala_config.step_length;
        Float target_acceptance_rate = 0.57;
        Float A = 0.0;
        MALA::MALAVector current_u = init_samples[chain];
        if (current_u.size() == 0)
        {
            // No seed was selected for this chain (too few valid phase-1 paths);
            // leave its rows zero instead of indexing an empty state.
            continue;
        }
        std::pair<int, int> current_path_length = init_sample_path_lengths[chain];
        int n_bounces = current_path_length.first;
        assert(n_bounces == 1);
        int cam_bounces = current_path_length.second;
        assert(cam_bounces == 0);
        for (int j = 0; j < mala_config.num_samples; j++)
        {
            bool global = false;
            Float p_global = sampler.next1D();
            MALA::MALAVector proposed_u(n_bounces * 3);
            proposed_u.setZero();
            // Path lengths proposed by a large step; adopted only on acceptance. Distinct
            // names so they no longer shadow the chain index used for the output row.
            int proposed_n_bounces = n_bounces, proposed_cam_bounces = cam_bounces;
            if (p_global < mala_config.p_global) {
                A = mutateLargeStep(scene, &sampler, 1, n_bounces, cam_bounces, 
                                edge_dist, edge_indices, grid_distrb, current_u, proposed_u,
                                proposed_n_bounces, proposed_cam_bounces);
                global = true;
            } else {
                A = mutateSmallStep(scene, gm_doll.get(tid), &sampler, n_bounces, cam_bounces,
                            *mutation, *cache_ptr,
                            edge_dist, edge_indices, edge_indices_inv, face_to_edge_indices,
                            step_length, mala_config,
                            current_u, proposed_u);
            }
            Float a = sampler.next1D();
            if (a < A)
            {
                thread_accepted++;
                if (current_u.size() != proposed_u.size()){
                    current_u.resize(proposed_u.size());
                }
                if (global){
                    n_bounces = proposed_n_bounces;
                    cam_bounces = proposed_cam_bounces;
                    mutation->setZero();
                }
                current_u = proposed_u;
            }
            // Adapt the step size towards the target acceptance rate, with decaying strength.
            Float acceptance_rate = Float(thread_accepted + 1) / (j + 2);
            step_length = step_length * std::pow((acceptance_rate) / target_acceptance_rate, 1.0 / (Float(j) / 1000.0 + 1.0));

            if (current_u[0] > min[0] && current_u[0] < max[0] &&
                current_u[1] > min[1] && current_u[1] < max[1] &&
                current_u[2] > min[2] && current_u[2] < max[2])
            {
                MALA::MALAVector current_u_copy = current_u;
                current_u_copy[0] = (current_u_copy[0] - min[0]) / (max[0] - min[0]);
                current_u_copy[1] = (current_u_copy[1] - min[1]) / (max[1] - min[1]);
                current_u_copy[2] = (current_u_copy[2] - min[2]) / (max[2] - min[2]);
                // Each chain owns a disjoint block of rows, so no synchronisation is needed.
                point_cloud.row(chain * mala_config.num_samples + j) = current_u_copy;
            }
        }

#pragma omp critical
        {
            // Always accumulate so the reported acceptance rate is valid without verbose.
            accepted += thread_accepted;
            if (verbose)
                progressIndicator(static_cast<Float>(++blockProcessed) / mala_config.num_chains);
        }
    }

    for (MALA::KNNCache *c : caches)
        delete c;

    std::cout << std::endl;
    PSDR_INFO("acceptance rate: {}", Float(accepted) / (mala_config.num_chains * mala_config.num_samples));
    return point_cloud;
}

/**
 * Runs only phase 1 of solve_MALA (seed-path generation plus weighted reservoir selection
 * of num_chains seeds) and returns every seed path for plotting.
 *
 * @param scene    scene to sample
 * @param size     unused; kept for interface compatibility
 * @param min,max  box used to rescale coordinates, via (u - min) / (max - min) (not clipped)
 * @return one row per seed path. Columns 0-2 hold its rescaled first three state
 *         coordinates. Column 3 is 1 if the path was selected as a chain seed and lies
 *         strictly inside (min, max), otherwise 0.
 */
Eigen::Array<Float, -1, 4, 1> MetropolisIndirectEdgeIntegrator::solve_MALA_rough(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const
{

/* --------------------------------------- Phase 1 ---------------------------------------*/

    int phase1_samples = 1<<(mala_config.phase_one_samples);
    Spectrum phase1_sum = Spectrum::Zero();

    int nblocks_phase1 = std::ceil(sqrt(phase1_samples));
    // Round up: with floor division, nblocks_phase1 * nsample_per_block < phase1_samples
    // whenever phase1_samples is not a perfect square, so samples were silently dropped
    // while the sum is still divided by phase1_samples. The idx check trims the excess.
    int nsample_per_block = (phase1_samples + nblocks_phase1 - 1) / nblocks_phase1;
    std::vector<SeedPath> seed_paths;

    int nworker = omp_get_num_procs();
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks_phase1; i++)
    {
        RndSampler sampler(0, i);
        Spectrum thread_sum = Spectrum::Zero();
        std::vector<SeedPath> seed_paths_thread;
        for (int j = 0; j < nsample_per_block; j++)
        {
            int idx = i * nsample_per_block + j;
            if (idx >= phase1_samples)
                break;
            Float pdf = 1.0;

            // Optionally warp the uniform sample through the guiding distributions.
            Vector d_rnd = sampler.next3D();
            if (grid_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = grid_distrb.sample(d_rnd, pdf);
            }
            if (aq_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = aq_distrb.sample(d_rnd, pdf);
            }

            std::vector<SeedPath> seed_paths_sample;
            Spectrum contrib = sampleIndirectBoundary(scene, d_rnd, sampler, 1, edge_dist, edge_indices, seed_paths_sample);
            if (std::isnan(contrib.sum() / pdf))
                continue;
            for (int k = 0; k < seed_paths_sample.size(); k++){
                seed_paths_sample[k].weight /= pdf;
                seed_paths_thread.push_back(seed_paths_sample[k]);
            }

            thread_sum += contrib / pdf;
        }
#pragma omp critical
        {
            phase1_sum[0] += thread_sum[0] / (Float)phase1_samples;
            phase1_sum[1] += thread_sum[1] / (Float)phase1_samples;
            phase1_sum[2] += thread_sum[2] / (Float)phase1_samples;
            
            for (int k = 0; k < seed_paths_thread.size(); k++){
                seed_paths.push_back(seed_paths_thread[k]);
            }
        }
    }
    RndSampler sampler_wrs(1, nblocks_phase1);
    PSDR_INFO("valid_sample: {}, num_chains: {}", seed_paths.size(), mala_config.num_chains);
    assert(seed_paths.size() > mala_config.num_chains);

    // Min-heap on the reservoir key: top() is the smallest key currently kept.
    auto comp = [](const std::pair<Float, int>& a, const std::pair<Float, int>& b) {
        return a.first > b.first;
    };

    // Weighted reservoir sampling: key u^(1/w); keep the num_chains largest keys.
    const size_t num_selected = static_cast<size_t>(mala_config.num_chains);
    std::priority_queue<std::pair<Float, int>, std::vector<std::pair<Float, int>>, decltype(comp)> pq(comp);
    for (int i = 0; i < seed_paths.size(); i++){
        if (seed_paths[i].weight.sum() < 1e-50) {
            continue;
        }
        Float r = pow(sampler_wrs.next1D(), 1.0 / (seed_paths[i].weight.sum()));
        if (pq.size() < num_selected){
            pq.push(std::make_pair(r, i));
        }
        else {
            if (pq.top().first < r){
                pq.pop();
                pq.push(std::make_pair(r, i));
            }
        }
    }

    Eigen::Array<Float, -1, 4, 1> point_cloud(seed_paths.size(), 4);
    for (int i = 0; i < seed_paths.size(); i++){
        MALA::MALAVector current_u_copy = seed_paths[i].pss;
        current_u_copy[0] = (current_u_copy[0] - min[0]) / (max[0] - min[0]);
        current_u_copy[1] = (current_u_copy[1] - min[1]) / (max[1] - min[1]);
        current_u_copy[2] = (current_u_copy[2] - min[2]) / (max[2] - min[2]);
        point_cloud.row(i)[0] = current_u_copy[0];
        point_cloud.row(i)[1] = current_u_copy[1];
        point_cloud.row(i)[2] = current_u_copy[2];
        point_cloud.row(i)[3] = 0.0;
    }
    // Flag selected seeds that lie strictly inside the box.
    while(!pq.empty()){
        int i = pq.top().second;
        pq.pop();
        if (seed_paths[i].pss[0] > min[0] && seed_paths[i].pss[0] < max[0] &&
            seed_paths[i].pss[1] > min[1] && seed_paths[i].pss[1] < max[1] &&
            seed_paths[i].pss[2] > min[2] && seed_paths[i].pss[2] < max[2])
        {
            point_cloud.row(i)[3] = 1.0;
        }
    }
    assert(!std::isnan(phase1_sum[1]));

    Spectrum phase1_mean = phase1_sum;
    PSDR_INFO("phase 1 sum: {}, {}, {}", phase1_sum[0], phase1_sum[1], phase1_sum[2]);
    PSDR_INFO("phase 1: {}, {}, {}", phase1_mean[0], phase1_mean[1], phase1_mean[2]);

    return point_cloud;
}

/**
 * Samples an edge ray from the 3D random vector rnd.
 *
 * @return 2 x 3 array: row 0 is eRec.ref (the sampled edge point), row 1 is eRec.dir.
 *         Both rows are zero if sampleEdgeRay could not pick an edge (shape_id == -1),
 *         since in that case it returns before setting the direction.
 */
ArrayX3d MetropolisIndirectEdgeIntegrator::get_edge_ray(const Scene &scene, const Vector &rnd) const {
    EdgeRaySamplingRecord eRec;
    sampleEdgeRay(scene, rnd, edge_dist, edge_indices, eRec);
    ArrayX3d edge_ray(2, 3);
    if (eRec.shape_id < 0) {
        // Avoid copying out uninitialised record fields.
        edge_ray.setZero();
        return edge_ray;
    }
    edge_ray.row(0) = eRec.ref;
    edge_ray.row(1) = eRec.dir;
    return edge_ray;
}


/**
 * Plotting helper: samples an edge ray from rnd, runs blockJump with the given mutation,
 * and returns the points blockJump records in mutation_path (one row per entry).
 *
 * Returns an empty (0 x 3) array if no edge could be sampled, or if the sampled edge is
 * not part of edge_dist. Previously both cases indexed arrays with -1.
 */
ArrayX3d MetropolisIndirectEdgeIntegrator::perturbe_sample(const Scene &scene, const Vector &rnd, const Vector &mutation) const {
    
/* -----------------------------Phase 0: Compute connectivity-----------------------------*/
    // edge_indices_inv[shape][edge] = position of that edge in edge_indices, or -1.
    std::vector<std::vector<int>> edge_indices_inv;
    edge_indices_inv.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        edge_indices_inv[i].resize(scene.shape_list[i]->edges.size());
        for (int j = 0; j < scene.shape_list[i]->edges.size(); j++){
            edge_indices_inv[i][j] = -1;
        }
    }
    for (int i = 0; i < edge_indices.size(); i++){
        edge_indices_inv[edge_indices[i].x()][edge_indices[i].y()] = i;
    }

    // face_to_edge_indices[shape][face] = up to three adjacent edges; unused slots stay -1.
    // NOTE(review): this stores the shape-local edge id edge.y(), whereas solve_MALA stores
    // the position i in edge_indices. Confirm which convention blockJump expects.
    std::vector<std::vector<Vector3i>> face_to_edge_indices;
    face_to_edge_indices.resize(scene.shape_list.size());
    for (int i = 0; i < scene.shape_list.size(); i++){
        face_to_edge_indices[i].resize(scene.shape_list[i]->indices.size());
        for (int j = 0; j < scene.shape_list[i]->indices.size(); j++){
            face_to_edge_indices[i][j] = Vector3i(-1, -1, -1);
        }
    }
    for (int i = 0; i < edge_indices.size(); i++){
        Vector2i edge = edge_indices[i];
        Shape* shape = scene.shape_list[edge.x()];
        Edge e = shape->edges[edge.y()];
        if (e.f0 >= 0){
            Vector3i &f0_to_edge = face_to_edge_indices[edge.x()][e.f0];
            if (f0_to_edge.x() < 0){
                f0_to_edge.x() = edge.y();
            } else if (f0_to_edge.y() < 0){
                f0_to_edge.y() = edge.y();
            } else if (f0_to_edge.z() < 0){
                f0_to_edge.z() = edge.y();
            }
        }
        if (e.f1 >= 0){
            Vector3i &f1_to_edge = face_to_edge_indices[edge.x()][e.f1];
            if (f1_to_edge.x() < 0){
                f1_to_edge.x() = edge.y();
            } else if (f1_to_edge.y() < 0){
                f1_to_edge.y() = edge.y();
            } else if (f1_to_edge.z() < 0){
                f1_to_edge.z() = edge.y();
            }
        }
    }
    algorithm1_MALA_indirect::EdgeBound bound;
    EdgeRaySamplingRecord eRec;
    sampleEdgeRay(scene, rnd, edge_dist, edge_indices, eRec);
    if (eRec.shape_id < 0) {
        // sampleEdgeRay failed; edge_indices_inv[-1] would be out of bounds.
        return ArrayX3d(0, 3);
    }
    bound.shape_idx = eRec.shape_id;
    bound.edge_idx = eRec.edge_id;
    int edge_id_dist = edge_indices_inv[bound.shape_idx][bound.edge_idx];
    if (edge_id_dist < 0) {
        // Edge not present in edge_dist: it has no CDF interval (m_cdf[-1] is invalid).
        return ArrayX3d(0, 3);
    }
    // The sampled edge's interval [min, max) in edge_dist's CDF.
    bound.min = edge_dist.m_cdf[edge_id_dist];
    bound.max = edge_dist.m_cdf[edge_id_dist + 1];
    bound.dir = eRec.dir;
    RndSampler sampler(0, 0);
    std::vector<std::pair<Vector, int>> mutation_path;
    Vector final_dir;
    blockJump(rnd, mutation, scene, bound, edge_indices, edge_indices_inv, edge_dist, face_to_edge_indices, &sampler, final_dir, mutation_path, true);
    PSDR_INFO("mutation_path size: {}", mutation_path.size());
    ArrayX3d path_array(mutation_path.size(), 3);
    for (int i = 0; i < mutation_path.size(); i++){
        path_array.row(i) = mutation_path[i].first;
    }
    return path_array;
}