#include "boundary.h"

#include <limits>

#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/nanoflann.hpp>
#include <render/photon_map.h>
#include <core/logger.h>
#include <unsupported/Eigen/CXX11/Tensor>
#include <emitter/area.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughconductor.h>
#include <bsdf/roughdielectric.h>

using MALA::PSS_State;

namespace {
    // Turns a random number into an integer drawn from {1, 2, ..., max}.
    // The zero-based bucket index is clamped to max - 1 before the +1 shift,
    // so rnd == 1 still yields max rather than max + 1.
    int sampleInt(const Float &rnd, const int &max) {
        const int bucket = (int)(rnd * max);
        const int last_bucket = max - 1;
        return (last_bucket < bucket ? last_bucket : bucket) + 1;
    }

    void sampleEdgeRay(const Scene &scene, const Array3 rnd,
                       const DiscreteDistribution &edge_dist,
                       const std::vector<Vector2i> &edge_indices,
                       EdgeRaySamplingRecord &eRec, 
                       EdgePrimarySampleRecord *ePSRec = nullptr)
    {
        /* Sample a point on the boundary */
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec, ePSRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample edge ray */
        // if (edge.f1 < 0) // Case 1: boundary edge
        // {
            eRec.dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            eRec.pdf /= 4. * M_PI;
        // }
        // else // Case 2: non-boundary edge
        // {
        //     Vector n0 = shape->getGeoNormal(edge.f0);
        //     Vector n1 = shape->getGeoNormal(edge.f1);
        //     Float pdf1;
        //     eRec.dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
        //     eRec.pdf *= pdf1;

        //     Float dotn0 = eRec.dir.dot(n0), dotn1 = eRec.dir.dot(n1);
        //     if (math::signum(dotn0) * math::signum(dotn1) > -0.5f)
        //     {
        //         std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
        //         eRec.shape_id = -1;
        //     }
        // }
        if (ePSRec != nullptr) {
            ePSRec->pdf = eRec.pdf;
        }
    }

    /**
     * Returns a heap-allocated clone of `bsdf` whose microfacet roughness is
     * raised to at least `cap`.
     *
     * Only BSDFs that report the class name "RoughConductorBSDF" or
     * "RoughDielectricBSDF" are modified; every other BSDF is cloned
     * unchanged. The callers in this file delete the returned pointer, i.e.
     * ownership is transferred to the caller.
     *
     * @param bsdf  BSDF to clone (must not be null).
     * @param cap   Lower bound applied to the distribution's m_alpha
     *              (must be > 0; defaults to 0.02).
     * @return      The clone, with m_distr rebuilt when alpha was below cap.
     */
    BSDF* get_new_capped_bsdf(const BSDF* bsdf, Float cap = 0.02) {
        assert(cap > 0.0);
        char name[100];
        name[0] = 0;
        bsdf->className(name);

        BSDF* copy = bsdf->clone();
        if (strcmp(name, "RoughConductorBSDF") == 0) {
            // Guard the downcast: if the clone is not the advertised type we
            // hand back the unmodified clone instead of dereferencing null.
            if (auto* rough = dynamic_cast<RoughConductorBSDF*>(copy)) {
                rough->m_distr = MicrofacetDistribution(rough->m_distr.m_alpha < cap ? cap : rough->m_distr.m_alpha);
            }
        } else if (strcmp(name, "RoughDielectricBSDF") == 0) {
            if (auto* rough = dynamic_cast<RoughDielectricBSDF*>(copy)) {
                rough->m_distr = MicrofacetDistribution(rough->m_distr.m_alpha < cap ? cap : rough->m_distr.m_alpha);
            }
        }
        return copy;
    }

    /**
     * Traces `num_paths` random paths through the scene in parallel and
     * appends one RadImpNode (position, throughput, depth) per recorded path
     * vertex to `nodes`.
     *
     * @param scene        Scene to trace in.
     * @param num_paths    Number of paths to start.
     * @param max_bounces  Paths are extended while depth <= max_bounces.
     * @param nodes        Output; the nodes of all threads are appended here.
     * @param importance   false: paths start at the camera (throughput 1);
     *                     true: paths start on an emitter, the emitter vertex
     *                     itself is recorded at depth 0.
     *
     * Surface scattering uses a roughness-capped clone of the hit BSDF
     * (get_new_capped_bsdf) rather than the BSDF stored on the intersection.
     */
    void build_PhotonMap_capped(const Scene &scene, int num_paths, int max_bounces, std::vector<RadImpNode> &nodes, bool importance)
    {
        // One sampler and one node buffer per worker so threads never share
        // mutable state inside the parallel loop.
        const int nworker = omp_get_num_procs();
        std::vector<RndSampler> samplers;
        for (int i = 0; i < nworker; ++i)
            samplers.push_back(RndSampler(17, i));

        std::vector<std::vector<RadImpNode>> nodes_per_thread(nworker);
        for (int i = 0; i < nworker; i++)
        {
            nodes_per_thread[i].reserve(num_paths / nworker * max_bounces);
        }

        const Camera &camera = scene.camera;
        const CropRectangle &rect = camera.rect;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (size_t omp_i = 0; omp_i < (size_t)num_paths; omp_i++)
        {
            const int tid = omp_get_thread_num();
            RndSampler &sampler = samplers[tid];
            Ray ray;
            Spectrum throughput = Spectrum::Zero();
            Intersection its;
            const Medium *ptr_med = nullptr;
            int depth = 0;
            bool onSurface = false;
            if (!importance)
            {
                // Trace ray from camera: pick a continuous pixel position,
                // restricted to the crop rectangle when one is set.
                Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width : sampler.next1D() * camera.width;
                Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height : sampler.next1D() * camera.height;
                ray = camera.samplePrimaryRay(x, y);
                throughput = Spectrum(1.0);
                ptr_med = camera.getMedID() == -1 ? nullptr : scene.medium_list[camera.getMedID()];
            }
            else
            {
                // Trace ray from emitter: record the emitter vertex first,
                // then sample an outgoing direction in the geometric frame.
                throughput = scene.sampleEmitterPosition(sampler.next2D(), its);
                ptr_med = its.ptr_med_ext;
                nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                ray.org = its.p;
                throughput *= its.ptr_emitter->sampleDirection(sampler.next2D(), ray.dir);
                ray.dir = its.geoFrame.toWorld(ray.dir);
                depth++;
                onSurface = true;
            }

            if (scene.rayIntersect(ray, onSurface, its))
            {
                while (depth <= max_bounces)
                {
                    // sampleDistance reports whether a scattering event was
                    // sampled inside the medium before reaching its.t; it
                    // also updates ray.org and throughput.
                    bool inside_med = ptr_med != nullptr &&
                                      ptr_med->sampleDistance(Ray(ray), its.t, sampler.next2D(), &sampler, ray.org, throughput);
                    if (inside_med)
                    {
                        // Volume vertex: record it, then scatter via the phase function.
                        if (throughput.isZero())
                            break;
                        const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
                        nodes_per_thread[tid].push_back(RadImpNode{ray.org, throughput, depth});
                        Float phase_val = ptr_phase->sample(-ray.dir, sampler.next2D(), ray.dir);
                        if (phase_val == 0.0)
                            break;
                        throughput *= phase_val;
                        // NOTE(review): the hit/miss result is ignored here;
                        // `its` is used by the next iteration regardless.
                        scene.rayIntersect(ray, false, its);
                    }
                    else
                    {
                        // Surface vertex: record it, then scatter using the
                        // roughness-capped clone of the hit BSDF.
                        nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                        Float bsdf_pdf, bsdf_eta;
                        Vector wo_local, wo;
                        BSDF* replacement_bsdf = get_new_capped_bsdf(its.ptr_bsdf);
                        Spectrum bsdf_weight = replacement_bsdf->sample(its, sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                                                              importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                        // Uncapped variant, kept for reference:
                        // Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                        //                                       importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                        delete replacement_bsdf; // the clone is only needed for this one sample
                        if (bsdf_weight.isZero())
                            break;
                        throughput = throughput * bsdf_weight;

                        // Stop when the geometric and shading frames disagree
                        // about which side wi or wo lies on.
                        wo = its.toWorld(wo_local);
                        Vector wi = -ray.dir;
                        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                            break;

                        if (its.isMediumTransition())
                            ptr_med = its.getTargetMedium(woDotGeoN);

                        ray = Ray(its.p, wo);
                        if (!scene.rayIntersect(ray, true, its))
                            break;
                    }
                    depth++;
                }
            }
        }
        // Merge the per-thread buffers into the caller's vector (appending).
        size_t sz_node = 0;
        for (int i = 0; i < nworker; i++)
            sz_node += nodes_per_thread[i].size();
        nodes.reserve(sz_node);
        for (int i = 0; i < nworker; i++)
            nodes.insert(nodes.end(), nodes_per_thread[i].begin(), nodes_per_thread[i].end());
    }

    /**
     * Looks up the NUM_NEAREST_NEIGHBORS photons closest to `query_point`.
     *
     * @param indices           KD-tree over the photon positions.
     * @param query_point       Pointer to the 3 coordinates of the query.
     * @param matched_indices   Output buffer with room for
     *                          NUM_NEAREST_NEIGHBORS indices.
     * @param matched_dist_sqr  Output: squared distance stored in the last
     *                          returned slot (presumably the farthest match,
     *                          assuming knnSearch returns neighbours sorted by
     *                          distance). Set to +infinity when nothing was
     *                          found, so a density estimate that divides by it
     *                          evaluates to zero.
     * @return Number of neighbours actually found.
     */
    int query_PhotonMap(const KDtree<Float> &indices, const Float *query_point, size_t *matched_indices, Float &matched_dist_sqr)
    {
        Float dist_sqr[NUM_NEAREST_NEIGHBORS];
        const int num_matched = indices.knnSearch(query_point, NUM_NEAREST_NEIGHBORS, matched_indices, dist_sqr);
        assert(num_matched == NUM_NEAREST_NEIGHBORS);
        if (num_matched <= 0)
        {
            // With asserts compiled out an empty result would otherwise read
            // dist_sqr[-1]; report "no neighbours" explicitly instead.
            matched_dist_sqr = std::numeric_limits<Float>::infinity();
            return 0;
        }
        matched_dist_sqr = dist_sqr[num_matched - 1];
        return num_matched;
    }

    /**
     * Evaluates a scalar, non-negative estimate of the direct boundary
     * integrand at the primary sample `rnd_val`, using a photon-map lookup
     * at the detector-side hit point in place of an explicit sensor path.
     *
     * Steps:
     *   1. sample an edge point and a ray direction (sampleEdgeRay);
     *   2. intersect the scene in both directions along that ray; one side
     *      must hit an emitter (that side becomes xS, the other xD);
     *   3. reject back-facing detector hits and degenerate geometry;
     *   4. require xD to be visible from the camera with non-negligible
     *      attenuation;
     *   5. sum the values of the nearest photons around xD whose depth is at
     *      most max_bounces and divide by the squared search radius.
     *
     * @param rnd_val       Primary sample: [0] edge point, [1..2] direction.
     * @param scene         Scene to evaluate in.
     * @param sampler       Unused; kept so the signature stays unchanged.
     * @param edge_dist     Distribution used to pick an edge.
     * @param edge_indices  (shape, edge) lookup that accompanies edge_dist.
     * @param max_bounces   Photons with node.depth > max_bounces are ignored.
     * @param rad_nodes     Photon records addressed by the KD-tree indices.
     * @param rad_indices   KD-tree over rad_nodes.
     * @return |estimate|, or 0 when any of the rejection tests fires.
     */
    Float eval_photon_DirectBoundary_capped(const Vector3 &rnd_val,
                                     const Scene &scene,
                                     RndSampler &sampler,
                                     const DiscreteDistribution &edge_dist,
                                     const std::vector<Vector2i> &edge_indices, int max_bounces,
                                     const std::vector<RadImpNode> &rad_nodes,
                                     const KDtree<Float> &rad_indices)
    {
        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, rnd_val, edge_dist, edge_indices, eRec);
        if (eRec.shape_id == -1)
            return 0.0;
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        // Both ends of the edge ray must land on geometry.
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, its;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, its))
            return 0.0;
        if (!itsS.ptr_shape->isEmitter())
        {
            if (!its.ptr_shape->isEmitter())
                return 0.0;
            // Only the flipped direction reaches an emitter: swap the roles
            // of the two hits. The sampling pdf is doubled in this case.
            std::swap(itsS, its);
            edgeRay = edgeRay.flipped();
            eRec.pdf *= 2.0;
        }

        // Emitted radiance leaving xS towards the edge point.
        Spectrum value = itsS.Le(-edgeRay.dir);
        if (value.isZero(Epsilon))
            return 0.0;
        const Vector xB = eRec.ref,
                     &xS = itsS.p;

        // NOTE prevent intersection with a backface
        Float gnDotD = its.geoFrame.n.dot(edgeRay.dir);
        Float snDotD = its.shFrame.n.dot(edgeRay.dir);
        bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                       (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        if (!success)
            return 0.0;

        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = edgeRay.dir;
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];

        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        const Vector &xD = its.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, its.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, its.geoFrame.n, xS, itsS.geoFrame.n);
        if (std::abs(baseValue) < Epsilon)
            return 0.0;

        // The detector point must be connectable to the camera.
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
            return 0.0;
        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;
        scene.camera.sampleDirect(its.p, pixel_uvs, attenuations, dir);
        if (attenuations.maxCoeff() < Epsilon)
            return 0.0;
        // (A roughness-capped BSDF used to be cloned and evaluated here, but
        // its value never entered the estimate, so the allocation and the
        // evaluation were dropped.)

        // Photon-map density estimate around the detector point.
        size_t matched_indices[NUM_NEAREST_NEIGHBORS];
        Float pt_rad[3] = {its.p[0], its.p[1], its.p[2]};
        Float matched_r2_rad = 0.0;
        const int num_nearby_rad = query_PhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
        // No neighbours, or a zero search radius, would turn the division
        // below into NaN/inf; treat both as "no contribution".
        if (num_nearby_rad <= 0 || !(matched_r2_rad > 0.0))
            return 0.0;

        Spectrum photon_radiances(0.0);
        for (int m = 0; m < num_nearby_rad; m++)
        {
            const RadImpNode &node = rad_nodes[matched_indices[m]];
            if (node.depth <= max_bounces)
                photon_radiances += node.val;
        }

        const Float estimate = (value * photon_radiances).maxCoeff() * baseValue / eRec.pdf / matched_r2_rad;
        return std::abs(estimate);
    }


    // File-local debug accumulator. renderD re-creates it at the start of a
    // render, and handleSurfaceInteraction adds to its per-thread images in
    // FORWARD builds.
    static DebugInfo debugInfo;

    bool test_Edge_Silhouette(const Scene &scene, int edge_idx, int shape_idx, const Vector &ray_dir) {
        const Shape *shape = scene.shape_list[shape_idx];
        const Edge &edge = shape->edges[edge_idx];

        Vector n0 = shape->getFaceNormal(edge.f0);
        Vector n1;
        if (edge.f1 >= 0) {
            n1 = shape->getFaceNormal(edge.f1);
        } else {
            n1 = -n0;
        }
        if (edge.mode == 0 || (edge.mode != -1 && n0.dot(ray_dir) * n1.dot(ray_dir) < 0)){
            return true;
        }
        return false;
    }

    /**
     * Checks whether a mutation from primary sample `rnd0` to `rnd1` is valid:
     * every edge listed in `mutation_path` must pass test_Edge_Silhouette for
     * the ray direction sampled before the mutation and for the one sampled
     * after it.
     *
     * All edges in `mutation_path` are looked up on the shape sampled from
     * `rnd0`.
     *
     * @param scene          Scene providing shapes and edges.
     * @param rnd0           Primary sample before the mutation.
     * @param rnd1           Primary sample after the mutation.
     * @param edge_dist      Distribution used to pick an edge.
     * @param edge_indices   (shape, edge) lookup that accompanies edge_dist.
     * @param mutation_path  Edge indices crossed by the mutation.
     * @return true when both samples are valid and every edge passes for both
     *         directions; false otherwise.
     */
    bool test_Mutation_Validity(const Scene &scene, const Vector &rnd0, const Vector &rnd1,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices, const std::vector<int> &mutation_path) {
        BoundarySamplingRecord eRec0, eRec1;
        sampleEdgeRay(scene, rnd0, edge_dist, edge_indices, eRec0);
        sampleEdgeRay(scene, rnd1, edge_dist, edge_indices, eRec1);

        // sampleEdgeRay leaves shape_id at -1 (and dir unset) when no edge
        // could be sampled. Such a mutation cannot be valid, and continuing
        // would index scene.shape_list with -1.
        if (eRec0.shape_id == -1 || eRec1.shape_id == -1)
            return false;

        const Vector dir0 = eRec0.dir;
        const Vector dir1 = eRec1.dir;

        for (const int edge_idx : mutation_path) {
            if (!test_Edge_Silhouette(scene, edge_idx, eRec0.shape_id, dir0) ||
                !test_Edge_Silhouette(scene, edge_idx, eRec0.shape_id, dir1)) {
                return false;
            }
        }
        return true;
    }

    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        const Shape *shapeD = scene.shape_list[bRec.shape_id_D];
        const auto &indD = shapeD->getIndices(bRec.tri_id_D);
        const Vector &xD_0 = shapeD->getVertex(indD[0]);
        const Vector &xD_1 = shapeD->getVertex(indD[1]);
        const Vector &xD_2 = shapeD->getVertex(indD[2]);

        res = normal_velocity(xS_0, xS_1, xS_2,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xD_0, xD_1, xD_2);
    }

    /**
     * Reverse-mode derivative of velocity(): seeds the output `res` of
     * velocity with `d_u` and accumulates the resulting gradients into
     * `d_scene`.
     *
     * The Enzyme call is only compiled when both ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined; otherwise this function is a no-op.
     *
     * @param scene    Primal scene (duplicated argument, paired with d_scene).
     * @param d_scene  Receives the scene gradients.
     * @param eRec     Sampling record, treated as a constant by the autodiff.
     *                 NOTE(review): velocity() takes a BoundarySamplingRecord;
     *                 presumably every caller passes one through this
     *                 EdgeSamplingRecord reference - confirm.
     * @param d_u      Adjoint of velocity's output.
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    EdgeSamplingRecord &eRec,
                    Float d_u)
    {
        // Primal output slot required by the autodiff call; its value is not used.
        [[maybe_unused]] Float u;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    /**
     * Connects the surface point `its.p` to the camera and back-propagates
     * the matching image gradient through velocity().
     *
     * Returns without doing anything when the camera cannot be sampled from
     * its.p, when its.p is not visible from the camera position, or when the
     * sampled pixel's sensor value is below Epsilon.
     *
     * Otherwise d_image[pixel] * weight * sensor_val is summed over its three
     * channels and passed to d_velocity as the adjoint.
     *
     * In FORWARD builds the change of d_scene.shape_list[shape_idx]->param
     * caused by that call is additionally added to the first channel of the
     * calling thread's debug image.
     */
    inline void handleSurfaceInteraction(const Intersection &its,
                                         const Scene &scene, Scene &d_scene,
                                         EdgeSamplingRecord &eRec,
                                         RndSampler &sampler, const Spectrum &weight,
                                         std::vector<Spectrum> &d_image)
    {
#if defined(FORWARD)
        int shape_idx = scene.getShapeRequiresGrad();
#endif
        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(its.p, cRec))
            return;
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
            return;
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensor_val < Epsilon)
            return;
        // Collapse the RGB adjoint to the scalar adjoint of velocity()'s output.
        Spectrum d_u_spec = d_image[pixel_idx] * weight * sensor_val;
        Float d_u = d_u_spec[0] + d_u_spec[1] + d_u_spec[2];
#ifdef FORWARD
        // Reset the tracked parameter so the difference taken after
        // d_velocity isolates this sample's contribution.
        d_scene.shape_list[shape_idx]->param = 0.;
        Float param = d_scene.shape_list[shape_idx]->param;
#endif
        d_velocity(scene, d_scene, eRec, d_u);
#ifdef FORWARD
        param = d_scene.shape_list[shape_idx]->param - param;
        const int tid = omp_get_thread_num();
        debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
#endif
    }

    /**
     * Differential counterpart of the direct-boundary sample: rebuilds the
     * boundary path described by the primary-sample vector `rnd` and, at the
     * last vertex of the detector sub-path, back-propagates `weight` through
     * handleSurfaceInteraction.
     *
     * @param scene         Primal scene.
     * @param d_scene       Receives gradients (via handleSurfaceInteraction).
     * @param rnd           Primary-sample vector; block 0 drives the edge-ray
     *                      sample, block i+1 drives the BSDF sample at bounce i.
     * @param sampler       Used for the pixel selection in handleSurfaceInteraction.
     * @param max_bounces   Number of detector vertices; only the vertex with
     *                      index max_bounces - 1 contributes.
     * @param edge_dist     Distribution used to pick an edge.
     * @param edge_indices  (shape, edge) lookup that accompanies edge_dist.
     * @param weight        Weight applied to the image gradient.
     * @param d_image       Per-pixel adjoint image.
     *
     * The function returns early (no contribution) when no edge is sampled,
     * when either end of the edge ray misses the scene, when the forward hit
     * is not an emitter, or when the geometric term is below Epsilon.
     */
    void d_sampleDirectBoundary(const Scene &scene, Scene &d_scene,
                                const MALA::MALAVector &rnd, RndSampler &sampler, int max_bounces,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices, const Spectrum &weight,
                                std::vector<Spectrum> &d_image)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        Vector rnd0 = MALA::pss_get(rnd, 0);
        sampleEdgeRay(scene, rnd0, edge_dist, edge_indices, eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        // Shoot the edge ray both ways: itsS is the hit along +dir (must be an
        // emitter), its is the hit along -dir (the detector side).
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, its;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, its))
            return;
        if (!itsS.ptr_shape->isEmitter())
            return;
        /* Emitted radiance at the emitter hit, towards the edge point */
        // DirectSamplingRecord dRec(eRec.ref);
        Spectrum value = itsS.Le(-edgeRay.dir);
        const Vector xB = eRec.ref,
                     &xS = itsS.p;
        // Ray ray(xB, (xB - xS).normalized());
        // Intersection its;
        // if (!scene.rayIntersect(ray, true, its))
        //     return;

        // Disabled sanity check:
        // make sure the ray is tangent to the surface
        // if (edge.f0 >= 0 && edge.f1 >= 0)
        // {
        //     Vector n0 = shape->getGeoNormal(edge.f0),
        //            n1 = shape->getGeoNormal(edge.f1);
        //     Float dotn0 = ray.dir.dot(n0),
        //           dotn1 = ray.dir.dot(n1);
        //     if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
        //         return;
        // }
        // Disabled back-face rejection (active in eval_photon_DirectBoundary_capped):
        // Float gnDotD = its.geoFrame.n.dot(-ray.dir);
        // Float snDotD = its.shFrame.n.dot(-ray.dir);
        // bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
        //                (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        // if (!success)
        //     return;
        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = edgeRay.dir;
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];

        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector &xD = its.p;
        // NOTE(review): n is computed but not read again in this function.
        Vector n = (v0 - v1).cross(-edgeRay.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, its.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, its.geoFrame.n, xS, itsS.geoFrame.n);
        if (std::abs(baseValue) < Epsilon)
            return;
        // assert(baseValue > -Epsilon);

        /* Sample detector path */
        // NOTE(review): throughput is not read again in this function.
        Spectrum throughput(1.0f);
        Ray ray_sensor;
        // PSDR_INFO("baseValue: {}", baseValue);
        for (int i = 0; i < max_bounces; i++)
        {
            // PSDR_INFO("value {}: {}, {}, {}", i+1, value[0], value[1], value[2]);
            // Only the final vertex of the detector sub-path contributes.
            if (i == max_bounces - 1){
                handleSurfaceInteraction(its, scene, d_scene,
                                        eRec, sampler, weight,
                                        d_image);
            }
            if (i == max_bounces - 1)
                break;
            // Extend the path: bounce i consumes primary-sample block i + 1.
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Vector rnd_i = MALA::pss_get(rnd, i+1);
            Spectrum bsdf_weight = its.sampleBSDF(rnd_i, wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            // Stop when the geometric and shading frames disagree about the
            // side wi or wo lies on.
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                break;
            // NOTE(review): value is updated here but not read afterwards in
            // this function.
            value *= bsdf_weight;
            ray_sensor = Ray(its.p, wo);
            scene.rayIntersect(ray_sensor, true, its);
            if (!its.isValid())
                break;
        }
    }

    void RD_BSDF_test(const RoughDielectricBSDF &rd, const Intersection& its, const Vector &rnd, Spectrum &value) {
        Vector wo_local;
        Float pdf, eta;
        value = rd.sample(its, rnd, wo_local, pdf, eta, EBSDFMode::ERadiance);
        // value = rd.eval(its, wo);
    }

    void RD_BSDF_test_echo(const RoughDielectricBSDF &rd, const Intersection& its, const Vector &rnd, Spectrum &value) {
        Vector wo_local;
        Float pdf, eta;
        value = rd.sample(its, rnd, wo_local, pdf, eta, EBSDFMode::ERadiance);
        PSDR_INFO("wo_local: {}, {}, {}", wo_local[0], wo_local[1], wo_local[2]);
        PSDR_INFO("value: {}, {}, {}", value[0], value[1], value[2]);
        // value = rd.eval(its, wo);
    }

    /**
     * Reverse-mode derivative of RD_BSDF_test with respect to the random
     * numbers.
     *
     * @param rd       BSDF under test (constant for the autodiff).
     * @param its      Intersection record (constant for the autodiff).
     * @param rnd      Random numbers fed to the BSDF sample.
     * @param d_rnd    Receives the gradient with respect to `rnd`.
     * @param d_value  Adjoint seed for the sampled value.
     *
     * The Enzyme call is only compiled when both ENZYME and
     * ENZYME_BOUNDARY_DIRECT are defined; otherwise d_rnd is left untouched.
     */
    void d_RD_BSDF_test(const RoughDielectricBSDF &rd, const Intersection& its,
                        const Vector &rnd, Vector &d_rnd,
                        Spectrum &d_value) {
        // Primal output slot paired with d_value in the autodiff call.
        [[maybe_unused]] Spectrum value = Spectrum::Zero();
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)RD_BSDF_test,
                          enzyme_const, &rd,
                          enzyme_const, &its,
                          enzyme_dup, &rnd, &d_rnd,
                          enzyme_dup, &value, &d_value
                          );
#endif
    }
}

/**
 * Debug helper comparing the autodiff gradient of a rough dielectric BSDF
 * sample against one-sided finite differences.
 *
 * A RoughDielectricBSDF(roughness, 1.5, 1.0) is sampled with the fixed random
 * numbers (0.5, 0.4, 0.3) and incident direction (0, sin(theta0), cos(theta0)).
 * `theta1` is currently unused.
 *
 * @return 9 values: [0..2] autodiff gradient with respect to the random
 *         numbers, [3..5] finite-difference gradient (step 1e-4) of the
 *         channel sum, [6..8] the sampled value itself.
 */
ArrayXd MetropolisDirectEdgeIntegrator::diff_bsdf_test(const Float &roughness, const Float &theta0, const Float &theta1) const {
    const double fd_step = 0.0001;

    RoughDielectricBSDF rd(roughness, 1.5, 1.0);
    Intersection its;
    its.wi = Vector(0.0, sin(theta0), cos(theta0));
    const Vector rnd(0.5, 0.4, 0.3);

    // Reference evaluation.
    Spectrum value = Spectrum::Zero();
    RD_BSDF_test_echo(rd, its, rnd, value);

    // Finite differences: perturb one random number at a time.
    Vector fd_d_rnd = Vector::Zero();
    for (int axis = 0; axis < 3; ++axis) {
        Vector perturbed = rnd;
        perturbed[axis] += fd_step;
        Spectrum perturbed_value = Spectrum::Zero();
        RD_BSDF_test_echo(rd, its, perturbed, perturbed_value);
        fd_d_rnd[axis] = (perturbed_value - value).sum() / fd_step;
    }

    // Autodiff gradient, seeded with ones on every channel.
    Spectrum d_value = Spectrum(1.0, 1.0, 1.0);
    Vector d_rnd = Vector::Zero();
    d_RD_BSDF_test(rd, its, rnd, d_rnd, d_value);

    ArrayXd output_array(9);
    for (int k = 0; k < 3; ++k) {
        output_array(k) = d_rnd[k];
        output_array(3 + k) = fd_d_rnd[k];
        output_array(6 + k) = value[k];
    }
    return output_array;
}
/**
 * Evaluates the direct-boundary integrand on a regular size[0] x size[1] x
 * size[2] grid of primary samples spanning [min_bound, max_bound].
 *
 * Each cell is evaluated at its centre. The flat output index is decoded as
 * x = idx % size[0], y = (idx / size[0]) % size[1],
 * z = idx / (size[0] * size[1]), and every entry stores the sum of the three
 * channels returned by algorithm1_MALA_direct::eval.
 *
 * The work is spread over all processors with OpenMP; progress is reported
 * once per outer iteration.
 */
ArrayXd MetropolisDirectEdgeIntegrator::get_sample_vol(const Scene &scene, const Vector3i size, const Vector &min_bound, const Vector &max_bound) {
    const int nx = size[0];
    const int ny = size[1];
    const int nz = size[2];
    ArrayXd vol_array(nx * ny * nz);
    const int nworker = omp_get_num_procs();
    int blockProcessed = 0;
    Timer _("Sample Space Density");
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nx; ++i) {
        for (int j = 0; j < ny * nz; ++j) {
            const int idx = i * ny * nz + j;

            // Decode the flat index into grid coordinates.
            const int cell_x = idx % nx;
            const int cell_y = (idx / nx) % ny;
            const int cell_z = idx / (nx * ny);

            // Cell centre in the unit cube, then mapped into the bounds.
            Vector unit = Vector::Zero();
            unit[0] = (cell_x + 0.5) / nx;
            unit[1] = (cell_y + 0.5) / ny;
            unit[2] = (cell_z + 0.5) / nz;
            const Vector sample_pos = min_bound + (max_bound - min_bound).cwiseProduct(unit);

            MALA::MALAVector mala_rnd(3);
            for (int k = 0; k < 3; ++k)
                mala_rnd[k] = sample_pos[k];

            Spectrum contrib = algorithm1_MALA_direct::eval(scene, mala_rnd, 1, edge_dist, edge_indices);
            vol_array(idx) = contrib[0] + contrib[1] + contrib[2];
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++blockProcessed) / nx);
    }

    return vol_array;
}
/**
 * Evaluates the direct-boundary integrand on a 2D slice of primary sample
 * space.
 *
 * Coordinate `axis` is fixed to `u0`; the two remaining coordinates
 * ((axis + 1) % 3 and (axis + 2) % 3) are sampled at the cell centres of a
 * size_1 x size_2 grid over the unit square. The flat output index is decoded
 * as u1 = idx % size_1, u2 = idx / size_1, and every entry stores the sum of
 * the three channels returned by algorithm1_MALA_direct::eval, the same
 * integrand get_sample_vol tabulates.
 *
 * Previously the only assignment to the result was commented out, so the
 * function returned uninitialised memory; the array is now zero-initialised
 * and filled.
 *
 * @param scene   Scene to evaluate in.
 * @param axis    Fixed coordinate, must be 0, 1 or 2.
 * @param u0      Value of the fixed coordinate.
 * @param size_1  Resolution along (axis + 1) % 3.
 * @param size_2  Resolution along (axis + 2) % 3.
 * @return size_1 * size_2 values.
 */
ArrayXd MetropolisDirectEdgeIntegrator::get_sample_slice(const Scene &scene, int axis, Float u0, int size_1, int size_2) {
    // rnd[axis] below indexes a 3-vector.
    assert(axis >= 0 && axis < 3);

    ArrayXd vol_array = ArrayXd::Zero(size_1 * size_2);

    const int nworker = omp_get_num_procs();
    int blockProcessed = 0;
    Timer _("Sample Space Density");
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < size_1; ++i) {
        for (int j = 0; j < size_2; ++j) {
            int idx = i * size_2 + j;
            int u1 = idx % size_1;
            int u2 = idx / size_1;
            Vector rnd = Vector::Zero();
            rnd[axis] = u0;
            rnd[(axis + 1) % 3] = (u1 + 0.5) / size_1;
            rnd[(axis + 2) % 3] = (u2 + 0.5) / size_2;

            MALA::MALAVector mala_rnd(3);
            mala_rnd[0] = rnd[0];
            mala_rnd[1] = rnd[1];
            mala_rnd[2] = rnd[2];
            Spectrum contrib = algorithm1_MALA_direct::eval(scene, mala_rnd, 1, edge_dist, edge_indices);
            vol_array(idx) = contrib[0] + contrib[1] + contrib[2];
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++blockProcessed) / size_1);
    }

    return vol_array;
}

/**
 * Differential pass for the direct boundary (edge) term, driven by
 * Metropolis / MALA chains in primary sample space (PSS).
 *
 * Structure of the function as written:
 *   - Phase 1: draw `phase1_samples` candidate PSS vectors (the first 3D block
 *     optionally warped through `grid_distrb` and/or `aq_distrb`), evaluate
 *     `algorithm1_MALA_direct::eval` on each, accumulate the sum, and select
 *     `mala_config.num_chains` initial chain states with a priority queue.
 *   - Finite-difference check: compares `d_eval`'s PSS gradient on the first
 *     initial sample against forward differences, logs both, and then
 *     RETURNS `debugInfo.getArray()`.
 *   - Phase 2 (chain mutation + gradient accumulation into `sceneAD.gm`):
 *     everything after that unconditional return is currently unreachable.
 *
 * @param sceneAD    `val` is the scene evaluated; `gm` is zeroed on entry;
 *                   `der` / `gm` are only written by the unreachable Phase 2.
 * @param options    Supplies the seed, block size, max_bounces and
 *                   num_samples_secondary_edge_direct.
 * @param __d_image  Image-space adjoint; only consumed by the unreachable
 *                   Phase 2 code.
 * @return `debugInfo.getArray()` on both reachable return paths.
 *         NOTE(review): the final (unreachable) return wraps the same array in
 *         `flattened(...)`; confirm which form callers expect.
 */
ArrayXd MetropolisDirectEdgeIntegrator::renderD(
    SceneAD &sceneAD, RenderOptions &options, const ArrayXd &__d_image) const
{
    // Scratch gradient storage: handed to d_eval / mutateSmallStep, which need a
    // derivative scene to write into. It is zeroed at the end of the function and
    // never merged, so those scene gradients are discarded.
    Scene tmp_der = sceneAD.der;
    GradientManager<Scene> gm_doll(tmp_der, omp_get_num_procs());

    PSDR_INFO("MLTEdgeIntegrator::renderD with spp = {}",
              options.num_samples_secondary_edge_direct);
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_secondary_edge_direct;
    // NOTE(review): nblocks and nblock_samples are not referenced again in this function.
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        return debugInfo.getArray();

    Timer _("Direct boundary");

/* --------------------------------------- Phase 1 ---------------------------------------*/

    int blockProcessed = 0;
    // Hard-coded number of bootstrap samples (2^24).
    int phase1_samples = 1<<24;
    Spectrum phase1_sum = Spectrum::Zero();
    std::vector<MALA::MALAVector> init_samples;
    std::vector<MALA::MALAVector> candidates;
    std::vector<Spectrum> weights;
    std::vector<Spectrum> init_sample_weights;

    candidates.resize(phase1_samples);
    weights.resize(phase1_samples);
    init_samples.resize(mala_config.num_chains);
    init_sample_weights.resize(mala_config.num_chains);
    int valid_sample = 0;
    int nblocks_phase1 = std::ceil(static_cast<Float>(phase1_samples) / (options.block_size * options.block_size));
    // NOTE(review): integer division. When phase1_samples is not a multiple of
    // nblocks_phase1, nblocks_phase1 * nsample_per_block < phase1_samples, so the
    // trailing entries of `candidates` / `weights` are never written by the loop
    // below but are still read by the selection loop and counted in phase1_mean.
    int nsample_per_block = phase1_samples / nblocks_phase1;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks_phase1; i++)
    {
        // One sampler per block, seeded by the block index.
        RndSampler sampler(options.seed, i);
        Spectrum thread_sum = Spectrum::Zero();
        int thread_valid_sample = 0;
        for (int j = 0; j < nsample_per_block; j++)
        {
            int idx = i * nsample_per_block + j;
            // Path length is drawn uniformly from {1, ..., max_bounces}; n_pdf is its probability.
            int n_bounces = sampleInt(sampler.next1D(), options.max_bounces);
            Float n_pdf = 1.0 / options.max_bounces;
            if (idx >= phase1_samples)
                break;
            Float pdf = 1.0;
            MALA::MALAVector rnd(n_bounces * 3);
            Vector d_rnd = sampler.next3D();
            // Optional guiding: warp the first 3D block through whichever guiding
            // distributions have been set up (non-zero sum).
            // NOTE(review): if both are active, the second call receives the already
            // warped sample and the same `pdf` variable — confirm sample() semantics.
            if (grid_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = grid_distrb.sample(d_rnd, pdf);
            }
            if (aq_distrb.distrb.getSum() > Epsilon)
            {
                d_rnd = aq_distrb.sample(d_rnd, pdf);
            }
            MALA::pss_set(rnd, 0, d_rnd);
            // Remaining 3D blocks are plain uniform samples.
            // NOTE(review): this `j` shadows the outer loop variable `j`.
            for (int j = 1; j < n_bounces; j++){
                MALA::pss_set(rnd, j, sampler.next3D());
            }
            candidates[idx] = rnd;
            Spectrum contrib = algorithm1_MALA_direct::eval(scene, rnd, n_bounces, edge_dist, edge_indices) / n_pdf;
            assert(!std::isnan(contrib[1]));
            weights[idx] = contrib / pdf; // contrib/1.0
            thread_sum += contrib / pdf;
            assert(!std::isnan(thread_sum[1]));
            // A sample counts as valid when any channel is non-zero.
            if (abs(contrib[0]) + abs(contrib[1]) + abs(contrib[2]) > 0.0){
                thread_valid_sample++;
            }
        }
        // Fold this block's partial results into the shared accumulators.
#pragma omp critical
        {
            assert(!std::isnan(thread_sum[1]));
            phase1_sum[0] += thread_sum[0];
            phase1_sum[1] += thread_sum[1];
            phase1_sum[2] += thread_sum[2];
            assert(!std::isnan(phase1_sum[1]));
            valid_sample += thread_valid_sample;
        }
    }
    RndSampler sampler_wrs(options.seed, 0);
    PSDR_INFO("valid_sample: {}, num_chains: {}", valid_sample, mala_config.num_chains);
    // if (valid_sample < mala_config.num_chains){
    //     assert(false);
    // }
    // Top-up: if Phase 1 produced fewer valid samples than chains, keep drawing
    // (sequentially) and append samples until there are enough.
    // NOTE(review): validity here is "signed channel sum > 0", whereas the parallel
    // loop above uses the sum of absolute values; this loop does not terminate if
    // no sample with positive sum can be found.
    while (valid_sample < mala_config.num_chains) {
        int i = valid_sample; // NOTE(review): not used below.
        int n_bounces = sampleInt(sampler_wrs.next1D(), options.max_bounces);
        Float n_pdf = 1.0 / options.max_bounces;
        Float pdf = 1.0;
        MALA::MALAVector rnd(n_bounces * 3);
        Vector d_rnd = sampler_wrs.next3D();
        if (grid_distrb.distrb.getSum() > Epsilon)
        {
            d_rnd = grid_distrb.sample(d_rnd, pdf);
        }
        if (aq_distrb.distrb.getSum() > Epsilon)
        {
            d_rnd = aq_distrb.sample(d_rnd, pdf);
        }
        MALA::pss_set(rnd, 0, d_rnd);
        for (int j = 1; j < n_bounces; j++){
            MALA::pss_set(rnd, j, sampler_wrs.next3D());
        }
        Spectrum contrib = algorithm1_MALA_direct::eval(scene, rnd, n_bounces, edge_dist, edge_indices) / n_pdf;
        // eval_DirectBoundary(d_rnd, scene, 1, edge_dist, edge_indices, false, contrib);
        if (contrib[0] + contrib[1] + contrib[2] > 0.0){
            candidates.push_back(rnd);
            weights.push_back(contrib / pdf); // contrib/1.0
            valid_sample++;
            progressIndicator(static_cast<Float>(valid_sample) / mala_config.num_chains);
        }
    }
    std::cout << std::endl;

    // Select the initial chain states: a min-heap (top = smallest key) that keeps
    // the init_samples.size() candidates with the largest random key `r`.
    auto comp = [](const std::pair<Float, MALA::MALAVector>& a, const std::pair<Float, MALA::MALAVector>& b) {
        return a.first > b.first; // Return true if a is greater than b
    };
    Float total_weight = 0.0;
    std::priority_queue<std::pair<Float, MALA::MALAVector>, std::vector<std::pair<Float, MALA::MALAVector>>, decltype(comp)> pq(comp);
    for (int i = 0; i < candidates.size(); i++){
        total_weight += weights[i][0] + weights[i][1] + weights[i][2];
        // NOTE(review): by operator precedence the random number multiplies only
        // weights[i][0]; channels 1 and 2 are added unscaled. If the intent was
        // next1D() * (w0 + w1 + w2), parentheses are missing — confirm.
        Float r = sampler_wrs.next1D() * weights[i][0] + weights[i][1] + weights[i][2];
        if (pq.size() < init_samples.size()){
            pq.push(std::make_pair(r, candidates[i]));
        }
        else {
            if (pq.top().first < r){
                pq.pop();
                pq.push(std::make_pair(r, candidates[i]));
            }
        }
    }
    // Drain the heap into init_samples (smallest key lands at the highest index).
    // NOTE(review): normalize_f is accumulated but not used afterwards, and the
    // per-sample keys stored in init_sample_weights are overwritten just below.
    Float normalize_f = 0.0;
    while(!pq.empty()){
        init_samples[pq.size() - 1] = pq.top().second;
        init_sample_weights[pq.size() - 1] = pq.top().first;
        normalize_f += pq.top().first;
        pq.pop();
    }
    // Every chain gets the same weight: the mean summed-channel weight over all candidates.
    for (int i = 0; i < init_sample_weights.size(); i++){
        init_sample_weights[i] = Spectrum(total_weight / candidates.size());
    }
    assert(!std::isnan(phase1_sum[1]));

    Spectrum phase1_mean = phase1_sum / (Float)phase1_samples;
    PSDR_INFO("phase 1 sum: {}, {}, {}", phase1_sum[0], phase1_sum[1], phase1_sum[2]);
    PSDR_INFO("phase 1: {}, {}, {}", phase1_mean[0], phase1_mean[1], phase1_mean[2]);


    // finite difference test:
    // Compares the PSS gradient from d_eval with forward differences (step 1e-6)
    // of the summed-channel value, for the first initial sample only.
    // NOTE(review): eval is called with options.max_bounces here, while
    // candidate0 was generated with its own n_bounces (candidate0.size() / 3);
    // also assumes mala_config.num_chains >= 1 so that init_samples[0] exists.
    MALA::MALAVector candidate0 = init_samples[0];
    MALA::MALAVector d_candidate_fd = candidate0;
    algorithm1_MALA_direct::LightPathPSS pss(candidate0.size());
    PSDR_INFO("candidate0.size: {}", candidate0.size());
    PSDR_INFO("candidate0: {}, {}, {}", candidate0[0], candidate0[1], candidate0[2]);
    Spectrum f0 = algorithm1_MALA_direct::eval(scene, candidate0, options.max_bounces, edge_dist, edge_indices, &pss);
    algorithm1_MALA_direct::LightPathPSSAD pssAD(pss);
    PSDR_INFO("f0: {}", f0.sum());
    algorithm1_MALA_direct::d_eval(scene, gm_doll.get(0), pssAD);
    MALA::MALAVector d_candidate = pssAD.der.pss_state;
    for (int i = 0; i < candidate0.size(); i++) {
        MALA::MALAVector candidate0_ = candidate0;
        candidate0_[i] += 1e-6;
        Spectrum f = algorithm1_MALA_direct::eval(scene, candidate0_, options.max_bounces, edge_dist, edge_indices);
        d_candidate_fd[i] = (f[0] + f[1] + f[2] - f0[0] - f0[1] - f0[2]) / 1e-6;
    }
    PSDR_INFO("d_candidate: {}, {}, {}", d_candidate[0], d_candidate[1], d_candidate[2]);
    PSDR_INFO("d_candidate_fd: {}, {}, {}", d_candidate_fd[0], d_candidate_fd[1], d_candidate_fd[2]);
    // NOTE(review): unconditional return — everything below (Phase 2 and the
    // gradient merge) is unreachable as the code stands. Looks like debugging
    // scaffolding; confirm before relying on this function for gradients.
    return debugInfo.getArray();


    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / mala_config.num_chains / mala_config.num_samples, camera.getNumPixels());

    // for (int i = 0; i < _d_image_spec_list.size(); i++){
    //     _d_image_spec_list[i] *= phase1_mean;
    // }

/* --------------------------------------- Phase 2 ---------------------------------------*/
    blockProcessed = 0;
    int cache_size = 8000;
    // One mutation strategy object per worker thread, chosen by mala_config.mode.
    // NOTE(review): arrays sized by the runtime value `nworker` are a compiler
    // extension (variable-length arrays), not standard C++.
    MALA::Mutation* mutations[nworker];
    // MutationHybrid* mutations_ref[nworker];
    for (int i = 0; i < nworker; i++){
        switch (mala_config.mode) {
            case 0:
                mutations[i] = new MALA::MutationDiminishing();
                break;
            case 1:
                mutations[i] = new MALA::MutationCacheBased();
                break;
            case 2:
                mutations[i] = new MALA::MutationHybrid();
                break;
            default:
                mutations[i] = new MALA::MutationDiminishing();
                break;
        }
        // mutations_ref[i] = new MutationHybrid();
    }
    // One KNN cache per worker thread; both arrays are freed after the chain loop.
    MALA::KNNCache* cache[nworker];
    // MALA::GridCache* grid_cache[nworker];
    // int max_cache = 2000;
    for (int i = 0; i < nworker; i++){
        cache[i] = new MALA::KNNCache(options.max_bounces * 3, cache_size);
        // grid_cache[i] = new MALA::GridCache(cache_size);
    }
    int accepted = 0;
    int burn_in = mala_config.burn_in;
    // Each loop iteration runs one full Markov chain.
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < mala_config.num_chains; i++)
    {
        const int tid = omp_get_thread_num();
        int thread_accepted = 0;
        RndSampler sampler(options.seed, i);
        MALA::Mutation *mutation = mutations[tid];
        // MutationHybrid *mutation_ref = mutations_ref[tid];
        // PSDR_INFO("chain start: tid {}", tid);
        // The mutation object is shared by all chains run on this thread, so it
        // is reset at the start of each chain.
        mutation->setZero();
        MALA::KNNCache *cache_ptr = cache[tid];
        // MALA::GridCache *grid_cache_ptr = grid_cache[tid];

// #pragma omp critical
//         {
//             if (!global_cache->write) {
//                 cache_ptr->copy(*global_cache);
//             }
//         }
        // NOTE(review): step_length is not referenced again in this loop.
        Float step_length = mala_config.step_length;

        Float A = 0.0;
        MALA::MALAVector current_u = init_samples[i];
        // Number of 3D blocks in this chain's state; fixed for the whole chain.
        int n_bounces = current_u.size() / 3 ;
        // assert(n_bounces == 2);
        for (int j = 0; j < mala_config.num_samples + burn_in; j++)
        {
            // current_u = sampler.next3D();
            // Float weight;
            // eval_DirectBoundary(current_u, scene, 1, edge_dist, edge_indices, true, weight);
            // d_sampleDirectBoundary(scene, gm.get(tid), sampler,
            //                         edge_dist, edge_indices, 1.0,
            //                         _d_image_spec_list, current_u);
            // PSDR_INFO("chain: {}, sample: {}", i, j);
            // With probability mala_config.p_global take a large (independent) step,
            // otherwise a gradient-informed small step. Both return the acceptance
            // probability A and fill proposed_u.
            Float p_global = sampler.next1D();
            MALA::MALAVector proposed_u(n_bounces * 3);
            proposed_u.setZero();
            // PSDR_INFO("here1");
            if (p_global < mala_config.p_global) {
                A = mutateLargeStep(scene, &sampler, n_bounces, current_u, proposed_u);
                mutation->setZero();
                // sample a point on the boundary
            } else {
                A = mutateSmallStep(scene, gm_doll.get(tid), &sampler, n_bounces,
                            *mutation, *cache_ptr,
                            current_u, proposed_u);
            }
            // PSDR_INFO("here2");
            // Metropolis-Hastings accept/reject.
            Float a = sampler.next1D();
            if (a < A)
            {
                thread_accepted++;
                current_u = proposed_u;
            }
            // After burn-in, accumulate gradients for every `thinning`-th state.
            // NOTE(review): strict `>` means the state at j == burn_in is skipped.
            if (j > burn_in){
                if (j % mala_config.thinning == 0){
                    d_sampleDirectBoundary(scene, gm.get(tid), current_u, sampler, n_bounces,
                                        edge_dist, edge_indices, init_sample_weights[i],
                                        _d_image_spec_list);
                }
            }
            // PSDR_INFO("here3");
        }
        // PSDR_INFO("local cache size: {}", cache_ptr->size());
        // PSDR_INFO("chain end");
//         if (global_cache->write){
// #pragma omp critical
//             {
//                 global_cache->merge(*cache_ptr);
//                 if (global_cache->size() >= cache_size){
//                     global_cache->write = false;
//                 }
//             }
//         }
//         delete cache_ptr;

        // Progress and acceptance statistics are only collected in verbose mode.
        if (verbose){
#pragma omp critical
            {
                progressIndicator(static_cast<Float>(++blockProcessed) / mala_config.num_chains);
                accepted += thread_accepted;
            }
            // PSDR_INFO("acceptance rate: {}", Float(thread_accepted) / (10000));
        }
    }
    if (verbose){
        std::cout << std::endl;
        // NOTE(review): thread_accepted also counts acceptances during burn-in,
        // but the denominator excludes burn_in iterations.
        PSDR_INFO("acceptance rate: {}", Float(accepted) / (mala_config.num_chains * mala_config.num_samples));
    }
    for (int i = 0; i < nworker; i++){
        delete mutations[i];
        delete cache[i];
    }

    // merge gradient to d_scene
    // The scratch gradients are cleared (not merged); only `gm` is merged.
    gm_doll.setZero();
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    d_precompute_normal(scene, d_scene);
#endif
    // delete global_cache;

    return flattened(debugInfo.getArray());
}

/**
 * Large-step mutation: propose a state drawn independently of `current`.
 *
 * The first 3D block of the proposal is a fresh uniform sample, warped through
 * `grid_distrb` when that distribution has been set up; every further block is
 * a plain uniform sample. Both states are then evaluated and the acceptance
 * probability is formed from the ratio of their summed-channel contributions,
 * corrected by the ratio of the guiding pdfs.
 *
 * @param scene        Scene to evaluate.
 * @param sampler      Random source; one 3D draw is consumed per block.
 * @param max_bounces  Number of 3D blocks in `current` / `proposal`.
 * @param current      Current chain state (read only).
 * @param proposal     Receives the proposed state; must already hold
 *                     max_bounces * 3 entries.
 * @return Acceptance probability clamped to [0, 1].
 */
Float MetropolisDirectEdgeIntegrator::mutateLargeStep(const Scene &scene, RndSampler *sampler, int max_bounces, const MALA::MALAVector &current, MALA::MALAVector &proposal) const { // start a new chain
    Float pdf_of_proposal = 1.0;
    Float pdf_of_current = 1.0;

    // First block: uniform draw, optionally importance-warped by the grid.
    Vector first_block = sampler->next3D();
    const bool grid_ready = grid_distrb.distrb.getSum() > Epsilon;
    if (grid_ready)
    {
        Vector current_first_block = MALA::pss_get(current, 0);
        pdf_of_current = grid_distrb.query_pdf(current_first_block);
        first_block = grid_distrb.sample(first_block, pdf_of_proposal);
    }
    MALA::pss_set(proposal, 0, first_block);

    // Remaining blocks stay uniform (pdf 1), so they do not enter the ratio.
    for (int block = 1; block < max_bounces; ++block)
    {
        MALA::pss_set(proposal, block, sampler->next3D());
    }

    // Target values: sum of the three channels at each state.
    Spectrum spec_current = algorithm1_MALA_direct::eval(scene, current, max_bounces, edge_dist, edge_indices);
    Spectrum spec_proposal = algorithm1_MALA_direct::eval(scene, proposal, max_bounces, edge_dist, edge_indices);
    Float target_current = spec_current[0] + spec_current[1] + spec_current[2];
    Float target_proposal = spec_proposal[0] + spec_proposal[1] + spec_proposal[2];

    // Independence-sampler acceptance: (f(v) / f(u)) * (q(u) / q(v)).
    Float acceptance = target_proposal / target_current / pdf_of_proposal * pdf_of_current;
    return clamp(acceptance, 0.0, 1.0);
}

/**
 * Small-step (gradient-informed) mutation in primary sample space.
 *
 * Outline, as implemented below:
 *   1. Evaluate the current state, recording its light path, and look up the
 *      CDF interval [bound.min, bound.max] of the edge it lies on.
 *   2. Differentiate the path w.r.t. the PSS state, form g = d_u / f_u, and
 *      let `mutation.step` produce the parameters (m_u, M_u) of the forward
 *      Gaussian proposal.
 *   3. Sample an offset; reject if any component other than the first leaves
 *      [0, 1]. If the first component leaves the current edge's interval, walk
 *      along adjacent silhouette edges to remap it.
 *   4. Evaluate the proposal, build the reverse Gaussian via
 *      `mutation.step_hypo`, and return the Metropolis-Hastings acceptance
 *      probability.
 *
 * @param scene        Scene to evaluate.
 * @param d_scene      Derivative scene that d_eval writes into.
 * @param sampler      Random source.
 * @param max_bounces  Number of 3D blocks in the state vectors.
 * @param mutation     Mutation strategy (updated through step / step_hypo).
 * @param cache        KNN cache handed to the mutation strategy.
 * @param current      Current chain state (read only).
 * @param proposal     Receives the proposed state once the bound checks pass;
 *                     left untouched on the earlier rejections.
 * @return Acceptance probability in [0, 1]; 0.0 on any rejection.
 */
Float MetropolisDirectEdgeIntegrator::mutateSmallStep(const Scene &scene, Scene &d_scene, RndSampler *sampler, int max_bounces,
                            MALA::Mutation &mutation, MALA::KNNCache &cache,
                            const MALA::MALAVector &current, MALA::MALAVector &proposal) const {
    // if (reuse) {
    //     current = proposal;
    // } else {

    MALA::PSS_State current_state(max_bounces * 3);
    MALA::PSS_State proposal_state(max_bounces * 3);
    current_state.u = current;
    // current_state.u[0] = current[1];
    // current_state.u[1] = current[2];
    algorithm1_MALA_direct::EdgeBound bound;
    algorithm1_MALA_direct::LightPathPSS path(max_bounces * 3);
    Spectrum fu_spec(0.0f);
    // Evaluate the current state and record the path for differentiation below.
    fu_spec = algorithm1_MALA_direct::eval(scene, current, max_bounces, edge_dist, edge_indices, &path);
    // eval_DirectBoundary(current, scene, 1, edge_dist, edge_indices, false, fu_spec, &bound);
    bound = path.bound;
    // Map (shape, edge) to its slot in edge_dist; the CDF entries on either side
    // of that slot delimit the range of u[0] that selects this edge.
    int edge_id_dist = edge_indices_inv[bound.shape_idx][bound.edge_idx];
    bound.min = edge_dist.m_cdf[edge_id_dist];
    bound.max = edge_dist.m_cdf[edge_id_dist + 1];
    current_state.f_u = fu_spec[0] + fu_spec[1] + fu_spec[2];
    // gaussian_test(current_state.u, current_state.f_u);
    // The current state must have a non-zero target value (it is divided by below).
    if (current_state.f_u == 0){
        assert(false);
    }
    Float log_pdf_uv = 0;
    Float log_pdf_vu = 0;

    MALA::MALAVector m_u, M_u;
    // if (cache.write || !mutation.step_readonly(cache, current_state.u, Vector(1.0, 1.0, 1.0), m_u, M_u)){
    // Vector d_u_ = d_eval_DirectBoundary(current, scene, d_scene, 1, edge_dist, edge_indices);
    // Gradient of the path contribution w.r.t. the PSS state; dividing by f_u
    // gives the gradient of log f.
    algorithm1_MALA_direct::LightPathPSSAD pathAD(path);
    algorithm1_MALA_direct::d_eval(scene, d_scene, pathAD);
    MALA::MALAVector d_u = pathAD.der.pss_state;
    current_state.g = d_u / current_state.f_u;
    mutation.step(cache, current_state.u, current_state.g, m_u, M_u, 0);
    // } // else: cache full, step_r returns some valid value, no need for gradient calculation.
    MALA::Gaussian gaussian_uv;
    MALA::ComputeGaussian(mala_config.step_length, m_u, M_u, path.discrete_dim, gaussian_uv);
    // Vector w = sampler->next2D();
    // Offset components are scaled by (1 - discrete_dim) before being added.
    // NOTE(review): presumably discrete_dim is a 0/1 mask so flagged dimensions
    // are left unchanged — confirm.
    proposal_state.u = gaussian_uv.GenerateSample(sampler, max_bounces * 3).cwiseProduct(1.0 - path.discrete_dim) + current_state.u;
    MALA::MALAVector offset = proposal_state.u - current_state.u;
    log_pdf_uv = gaussian_uv.GaussianLogPdf(offset, path.discrete_dim);
    // Components 1.. must stay inside the unit interval; component 0 (position
    // along the edge CDF) is handled separately below.
    for (int i = 1; i < max_bounces * 3; i++){
        if (proposal_state.u[i] < 0.0 || proposal_state.u[i] > 1.0){
            return 0.0; // reject
        }
    }
    // The proposal left the current edge's CDF interval: walk across neighbouring
    // edges, spending |offset[0]| of CDF length, to find where it lands.
    if (proposal_state.u[0] < bound.min || proposal_state.u[0] > bound.max){
        // return 0.0;
        std::vector<int> mutation_path;
        // NOTE(review): this draw is overwritten inside the loop before being
        // read; its only visible effect is to advance the sampler.
        Float rnd = sampler->next1D();
        Float distance = abs(offset[0]);
        Shape* shape = scene.shape_list[bound.shape_idx];
        Float curr_edgelen;
        int transition_vertex_idx;
        bool v0_based;
        if (proposal_state.u[0] < bound.min) { // walk the first step depending on the direction to walk
            curr_edgelen = abs(current_state.u[0] - bound.min);
            transition_vertex_idx = shape->edges[bound.edge_idx].v0;
            v0_based = true;
        } else {
            curr_edgelen = abs(current_state.u[0] - bound.max);
            transition_vertex_idx = shape->edges[bound.edge_idx].v1;
            v0_based = false;
        }
        int curr_edge_idx = bound.edge_idx;
        int curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
        mutation_path.push_back(curr_edge_idx);
        int counter = 0; // NOTE(review): incremented below but never read.
        while (distance > curr_edgelen){ // mutate to the appropriate edge
            // PSDR_INFO("distance: {}, curr_edgelen: {}", distance, curr_edgelen);
            // PSDR_INFO("curr_edge_idx: {}, transition_vertex_idx: {}", curr_edge_idx, transition_vertex_idx);
            distance -= curr_edgelen;
            counter++;
            // work at the transition vertex
            std::vector<int> valid_edges;
            Vector transition_vertex = shape->vertices[transition_vertex_idx];
            Vector ray_dir = bound.emitter_dir;
            // if (bound.mode == 0){
            //     ray_dir = (transition_vertex - bound.emitter_point).normalized();
            // } else if (bound.mode == 1) {
            // }
            // PSDR_INFO("adjacent edge size: {}", shape->adjacentEdges[transition_vertex_idx].size());
            // PSDR_INFO("current edge: ");
            // test_Edge_Silhouette(scene, curr_edge_idx, bound.shape_idx, ray_dir);
            // Candidate continuations: every other edge at this vertex that passes
            // the silhouette test for ray_dir.
            for (int i = 0; i < shape->adjacentEdges[transition_vertex_idx].size(); i++){
                int edge_idx = shape->adjacentEdges[transition_vertex_idx][i];
                // Edge edge = shape->edges[edge_idx];
                if (edge_idx == curr_edge_idx){
                    continue;
                }
                // Vector n0 = shape->getFaceNormal(edge.f0);
                // Vector n1 = shape->getFaceNormal(edge.f1);
                // if (edge.mode == 0 || (n0.dot(bound.ray_dir) * n1.dot(bound.ray_dir) < 0 && edge.mode != -1)){
                //     valid_edges.push_back(edge_idx);
                // }
                if (test_Edge_Silhouette(scene, edge_idx, bound.shape_idx, ray_dir)){
                    // PSDR_INFO("edge silhouette: {}", edge_idx);
                    // PSDR_INFO("ray_dir: {}, {}, {}", ray_dir[0], ray_dir[1], ray_dir[2]);
                    // Vector vertex0 = shape->vertices[shape->edges[edge_idx].v0];
                    // PSDR_INFO("endpoint0: {}, {}, {}", vertex0[0], vertex0[1], vertex0[2]);
                    // Vector vertex1 = shape->vertices[shape->edges[edge_idx].v1];
                    // PSDR_INFO("endpoint1: {}, {}, {}", vertex1[0], vertex1[1], vertex1[2]);
                    valid_edges.push_back(edge_idx);
                }
            }
            // assert(valid_edges.size() == 1);
            if (valid_edges.size() == 0){
                // reject
                return 0.0; // remain current state
            }
            // Uniform choice among the valid continuations.
            // NOTE(review): assumes next1D() < 1 so `picked` stays in range; the
            // 1 / valid_edges.size() choice probability is not folded into
            // log_pdf_uv / log_pdf_vu — confirm that is intended.
            rnd = sampler->next1D();
            int picked = floor(rnd * valid_edges.size()); // randomly pick an edge at the transition vertex
            // sample reuse
            // rnd = rnd * valid_edges.size() - picked;

            // move to the next transition vertex
            curr_edge_idx = valid_edges[picked];
            curr_edge_idx_dist = edge_indices_inv[bound.shape_idx][curr_edge_idx];
            // PSDR_INFO("picked {}: {}", mutation_path.size(), curr_edge_idx);
            mutation_path.push_back(curr_edge_idx);
            // The far endpoint of the picked edge becomes the next transition
            // vertex; v0_based records whether the walk enters at v0.
            if (transition_vertex_idx == shape->edges[curr_edge_idx].v0){
                transition_vertex_idx = shape->edges[curr_edge_idx].v1;
                v0_based = true;
            }
            else {
                transition_vertex_idx = shape->edges[curr_edge_idx].v0;
                v0_based = false;
            }
            // NOTE(review): presumably edge_dist[k] is the edge's share of the CDF
            // (bin width), i.e. its length in PSS units — confirm.
            curr_edgelen = edge_dist[curr_edge_idx_dist];
            // curr_edgelen = (shape->vertices[shape->edges[curr_edge_idx].v0] - shape->vertices[shape->edges[curr_edge_idx].v1]).norm();
        }
        // Place u[0] inside the final edge's CDF interval, measured from the end
        // at which the walk entered.
        Float remaining_dist_pss = (distance);
        if (v0_based) { // walk from v0 to v1
            proposal_state.u[0] = edge_dist.m_cdf[curr_edge_idx_dist] + remaining_dist_pss;
        }
        else {
            proposal_state.u[0] = edge_dist.m_cdf[curr_edge_idx_dist + 1] - remaining_dist_pss;
        }
        if (!(proposal_state.u[0] >= edge_dist.m_cdf[curr_edge_idx_dist] && proposal_state.u[0] <= edge_dist.m_cdf[curr_edge_idx_dist + 1])) {
            assert(false);
        }
        if (!test_Mutation_Validity(scene, MALA::pss_get(current_state.u, 0), MALA::pss_get(proposal_state.u, 0), edge_dist, edge_indices, mutation_path)) {
            // PSDR_INFO("mutation path: {}", mutation_path.size());
            return 0.0; // remain current state
        }
    }
    // From here on the proposal is published to the caller, even if one of the
    // rejections below returns 0.
    proposal = proposal_state.u;
    Spectrum fv_spec(0.0f);
    algorithm1_MALA_direct::LightPathPSS path_v(max_bounces * 3);
    // eval_DirectBoundary(proposal_state.u, scene, 1, edge_dist, edge_indices, false, fv_spec);
    fv_spec = algorithm1_MALA_direct::eval(scene, proposal, max_bounces, edge_dist, edge_indices, &path_v);
    proposal_state.f_u = fv_spec[0] + fv_spec[1] + fv_spec[2];
    if (proposal_state.f_u == 0.0){
        // reject
        return 0.0; // remain current state
    }
    // Reject mutations that change the path type string.
    if (strcmp(path.type, path_v.type) != 0){
        // reject
        return 0.0; // remain current state
    }
    MALA::MALAVector d_rnd_v(max_bounces * 3);
    d_rnd_v.setZero();

    // Reverse proposal: same construction as the forward one, but evaluated at
    // the proposed state and through step_hypo.
    MALA::MALAVector m_v, M_v;
    algorithm1_MALA_direct::LightPathPSSAD pathAD_v(path_v);
    algorithm1_MALA_direct::d_eval(scene, d_scene, pathAD_v);
    d_rnd_v = pathAD_v.der.pss_state;
    proposal_state.g = d_rnd_v / proposal_state.f_u;
    // proposal_state.g[0] = d_rnd_v[1] / proposal_state.f_u;
    // proposal_state.g[1] = d_rnd_v[2] / proposal_state.f_u;
    mutation.step_hypo(cache, proposal_state.u, proposal_state.g, m_v, M_v, 0);
    MALA::Gaussian gaussian_vu;
    MALA::ComputeGaussian(mala_config.step_length, m_v, M_v, path_v.discrete_dim, gaussian_vu);

    // `offset` was computed before u[0] was remapped by the edge walk, so
    // offset[0] is the original (unwrapped) displacement.
    log_pdf_vu = gaussian_vu.GaussianLogPdf(-offset, path_v.discrete_dim);

    // Metropolis-Hastings ratio: (q(v->u) / q(u->v)) * (f(v) / f(u)).
    Float A = exp(log_pdf_vu - log_pdf_uv) * proposal_state.f_u / current_state.f_u;
    A = clamp(A, 0.0, 1.0);
    return A;
}

// roughness-capped guiding

/**
 * Set up the regular-grid guiding distribution (`grid_distrb`).
 *
 * Builds a capped photon map of radiance nodes, indexes the node positions in
 * a KD-tree, and hands `grid_distrb.setup` an evaluator that queries the
 * photon-based direct-boundary estimate (NaN results are mapped to 0).
 *
 * @param scene        Scene used for photon tracing and evaluation.
 * @param config       Grid configuration forwarded to `grid_distrb.setup`.
 * @param max_bounces  Bounce limit forwarded to the photon map and evaluator.
 */
void MetropolisDirectEdgeIntegrator::preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
{
    // Photon-map pass: collect radiance nodes along camera paths.
    PhotonMapOptions photon_opts(10000, 10000, max_bounces);
    std::cout << "[INFO] Direct Guiding: #camPath = " << photon_opts.num_cam_path << std::endl;

    std::vector<RadImpNode> rad_nodes;
    build_PhotonMap_capped(scene, photon_opts.num_cam_path, max_bounces, rad_nodes, false);
    std::cout << "[INFO] Direct Guiding: #rad_nodes = " << rad_nodes.size() << std::endl;

    // Copy node positions into the point cloud backing the KD-tree.
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    size_t slot = 0;
    for (const auto &node : rad_nodes)
    {
        auto &point = rad_cloud.pts[slot];
        point.x = node.p[0];
        point.y = node.p[1];
        point.z = node.p[2];
        ++slot;
    }

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();
    std::cout << "preprocessing DirectEdgeIntegrator" << std::endl;

    // Evaluator for the grid builder; NaN estimates contribute nothing.
    auto guided_estimate = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        Float estimate = eval_photon_DirectBoundary_capped(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, rad_indices);
        if (isnan(estimate))
        {
            return 0.0;
        }
        return estimate;
    };
    grid_distrb.setup(guided_estimate, config);
    std::cout << "finish grid guiding" << std::endl;
}


/**
 * Set up the adaptive guiding distribution (`aq_distrb`).
 *
 * Builds a capped photon map of radiance nodes, indexes the node positions in
 * a KD-tree, and hands `aq_distrb.setup` an evaluator that queries the
 * photon-based direct-boundary estimate (NaN results are mapped to 0),
 * together with the initial CDF chosen by `config.edge_draw`
 * (`draw_dist.m_cdf` when set, `edge_dist.m_cdf` otherwise).
 *
 * @param scene        Scene used for photon tracing and evaluation.
 * @param config       Adaptive configuration forwarded to `aq_distrb.setup`.
 * @param max_bounces  Bounce limit forwarded to the photon map and evaluator.
 */
void MetropolisDirectEdgeIntegrator::preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
{
    // Photon-map pass: collect radiance nodes along camera paths.
    PhotonMapOptions photon_opts(10000, 10000, max_bounces);
    std::cout << "[INFO] Direct Guiding: #camPath = " << photon_opts.num_cam_path << std::endl;

    std::vector<RadImpNode> rad_nodes;
    build_PhotonMap_capped(scene, photon_opts.num_cam_path, max_bounces, rad_nodes, false);
    std::cout << "[INFO] Direct Guiding: #rad_nodes = " << rad_nodes.size() << std::endl;

    // Copy node positions into the point cloud backing the KD-tree.
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    size_t slot = 0;
    for (const auto &node : rad_nodes)
    {
        auto &point = rad_cloud.pts[slot];
        point.x = node.p[0];
        point.y = node.p[1];
        point.z = node.p[2];
        ++slot;
    }

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();

    std::cout << "AQ preprocessing DirectEdgeIntegrator" << std::endl;

    // Evaluator for the adaptive builder; NaN estimates contribute nothing.
    auto guided_estimate = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        Float estimate = eval_photon_DirectBoundary_capped(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, rad_indices);
        if (isnan(estimate))
        {
            return 0.0;
        }
        return estimate;
    };

    // Initial subdivision along the first dimension: per-edge CDF by default,
    // or the draw-edge CDF when requested by the configuration.
    std::vector<Float> cdfx;
    if (!config.edge_draw)
    {
        std::cout << "AQ using individual edge" << std::endl;
        cdfx = edge_dist.m_cdf;
    }
    else
    {
        std::cout << "AQ using draw edge" << std::endl;
        cdfx = draw_dist.m_cdf;
    }

    std::cout << "Inital curvature size: " << cdfx.size() << std::endl;
    aq_distrb.setup(guided_estimate, cdfx, config);
    // aq_distrb.print();
    std::cout << "finish AQ guiding" << std::endl;
}