#include "boundary.h"
#include <algorithm>
#include <cmath>
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/nanoflann.hpp>
#include <render/photon_map.h>
#include <core/logger.h>
#include <unsupported/Eigen/CXX11/Tensor>
#include <emitter/area.h>

namespace
{
    // Debug image buffer: re-created at the start of renderD / forwardRenderD and written
    // per thread (image_per_thread[tid]) by handleSurfaceInteraction in forward mode.
    static DebugInfo debugInfo;
    /**
     * Trace `num_paths` random paths through the scene and record one RadImpNode
     * (position, current path throughput, depth) at every surface hit and every
     * medium scattering event, for depths 0..max_bounces.
     *
     * @param scene       scene to trace.
     * @param num_paths   number of paths to trace.
     * @param max_bounces paths stop once depth exceeds this value.
     * @param nodes       output; the nodes of all threads are appended to it (existing
     *                    contents are kept).
     * @param importance  false: paths start at the camera (pixel sampled inside the crop
     *                    rectangle when it is valid, otherwise over the full image) with unit
     *                    throughput, and BSDFs are sampled in ERadiance mode.
     *                    true: paths start at a sampled emitter position, which is itself
     *                    recorded as a depth-0 node, and BSDFs are sampled in
     *                    EImportanceWithCorrection mode.
     *
     * Each OpenMP worker owns one RndSampler (seed 17, stream = worker index) and one node
     * list; the lists are concatenated in worker order at the end. With schedule(dynamic, 1)
     * the path-to-worker assignment is not fixed, so the random numbers a given path sees
     * (and the node order) may differ between runs — NOTE(review): confirm this is acceptable.
     */
    void buildPhotonMap(const Scene &scene, int num_paths, int max_bounces, std::vector<RadImpNode> &nodes, bool importance)
    {
        const int nworker = omp_get_num_procs();
        std::vector<RndSampler> samplers;
        for (int i = 0; i < nworker; ++i)
            samplers.push_back(RndSampler(17, i));

        // Per-thread output buffers avoid synchronisation inside the parallel loop.
        std::vector<std::vector<RadImpNode>> nodes_per_thread(nworker);
        for (int i = 0; i < nworker; i++)
        {
            nodes_per_thread[i].reserve(num_paths / nworker * max_bounces);
        }

        const Camera &camera = scene.camera;
        const CropRectangle &rect = camera.rect;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (size_t omp_i = 0; omp_i < (size_t)num_paths; omp_i++)
        {
            const int tid = omp_get_thread_num();
            RndSampler &sampler = samplers[tid];
            Ray ray;
            Spectrum throughput;
            Intersection its;
            const Medium *ptr_med = nullptr;
            int depth = 0;
            bool onSurface = false;
            if (!importance)
            {
                // Trace ray from camera
                // Exactly one next1D() per axis (x first, then y), whichever branch is taken.
                Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width : sampler.next1D() * camera.width;
                Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height : sampler.next1D() * camera.height;
                ray = camera.samplePrimaryRay(x, y);
                throughput = Spectrum(1.0);
                ptr_med = camera.getMedID() == -1 ? nullptr : scene.medium_list[camera.getMedID()];
            }
            else
            {
                // Trace ray from emitter
                throughput = scene.sampleEmitterPosition(sampler.next2D(), its);
                ptr_med = its.ptr_med_ext;
                // The emitter position itself is recorded as a depth-0 node.
                nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                ray.org = its.p;
                throughput *= its.ptr_emitter->sampleDirection(sampler.next2D(), ray.dir);
                // sampleDirection returns a local direction; convert it to world space.
                ray.dir = its.geoFrame.toWorld(ray.dir);
                depth++;
                // The emitter ray starts on a surface (passed to rayIntersect below).
                onSurface = true;
            }

            if (scene.rayIntersect(ray, onSurface, its))
            {
                while (depth <= max_bounces)
                {
                    // next2D() is consumed only when a medium is present (short-circuit &&).
                    // NOTE(review): sampleDistance presumably moves ray.org to the scattering
                    // point and updates throughput when it returns true — confirm.
                    bool inside_med = ptr_med != nullptr &&
                                      ptr_med->sampleDistance(Ray(ray), its.t, sampler.next2D(), &sampler, ray.org, throughput);
                    if (inside_med)
                    {
                        if (throughput.isZero())
                            break;
                        const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
                        nodes_per_thread[tid].push_back(RadImpNode{ray.org, throughput, depth});
                        Float phase_val = ptr_phase->sample(-ray.dir, sampler.next2D(), ray.dir);
                        if (phase_val == 0.0)
                            break;
                        throughput *= phase_val;
                        // Return value ignored: `its` is refreshed for the next free-flight sample.
                        // NOTE(review): a ray escaping the scene from inside a medium is not detected here.
                        scene.rayIntersect(ray, false, its);
                    }
                    else
                    {
                        nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                        Float bsdf_pdf, bsdf_eta;
                        Vector wo_local, wo;
                        Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                                                              importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                        if (bsdf_weight.isZero())
                            break;
                        throughput = throughput * bsdf_weight;

                        wo = its.toWorld(wo_local);
                        Vector wi = -ray.dir;
                        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                        // Terminate when the shading frame and the geometric normal disagree about
                        // which side wi / wo lie on.
                        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                            break;

                        // Crossing a medium interface: continue in the medium on wo's side.
                        if (its.isMediumTransition())
                            ptr_med = its.getTargetMedium(woDotGeoN);

                        ray = Ray(its.p, wo);
                        if (!scene.rayIntersect(ray, true, its))
                            break;
                    }
                    depth++;
                }
            }
        }
        // Concatenate the per-thread buffers (in thread order) onto the output.
        size_t sz_node = 0;
        for (int i = 0; i < nworker; i++)
            sz_node += nodes_per_thread[i].size();
        nodes.reserve(sz_node);
        for (int i = 0; i < nworker; i++)
            nodes.insert(nodes.end(), nodes_per_thread[i].begin(), nodes_per_thread[i].end());
    }

    /**
     * k-nearest-neighbour lookup (k = NUM_NEAREST_NEIGHBORS) in the photon-map KD-tree.
     *
     * @param indices          KD-tree built over the photon node positions.
     * @param query_point      the 3 coordinates of the query position.
     * @param matched_indices  output buffer with room for NUM_NEAREST_NEIGHBORS node indices.
     * @param matched_dist_sqr output: squared distance to the farthest neighbour returned,
     *                         or 0 when no neighbour was found.
     * @return number of neighbours found. This may be fewer than NUM_NEAREST_NEIGHBORS when
     *         the map holds fewer nodes, and it is 0 for an empty map.
     */
    int queryPhotonMap(const KDtree<Float> &indices, const Float *query_point, size_t *matched_indices, Float &matched_dist_sqr)
    {
        Float dist_sqr[NUM_NEAREST_NEIGHBORS];
        const int num_matched = static_cast<int>(
            indices.knnSearch(query_point, NUM_NEAREST_NEIGHBORS, matched_indices, dist_sqr));
        if (num_matched <= 0)
        {
            // No neighbour means no radius to report; dist_sqr[num_matched - 1] would be out of bounds.
            matched_dist_sqr = 0;
            return 0;
        }
        // nanoflann's KNN result set is sorted by increasing distance, so the last entry is the farthest.
        matched_dist_sqr = dist_sqr[num_matched - 1];
        return num_matched;
    }

    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        const Shape *shapeD = scene.shape_list[bRec.shape_id_D];
        const auto &indD = shapeD->getIndices(bRec.tri_id_D);
        const Vector &xD_0 = shapeD->getVertex(indD[0]);
        const Vector &xD_1 = shapeD->getVertex(indD[1]);
        const Vector &xD_2 = shapeD->getVertex(indD[2]);

        res = normal_velocity(xS_0, xS_1, xS_2,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xD_0, xD_1, xD_2);
    }

    /**
     * Reverse-mode derivative of velocity(): runs Enzyme with `d_u` as the adjoint of
     * velocity's output `res`, propagating the gradient w.r.t. the scene into `d_scene`.
     *
     * Only active when built with both ENZYME and ENZYME_BOUNDARY_DIRECT; otherwise it does
     * nothing. `rnd_light` is unused.
     *
     * NOTE(review): velocity() takes a BoundarySamplingRecord while this takes an
     * EdgeSamplingRecord; presumably the two names alias the same type — confirm.
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    Vector2 &rnd_light,
                    EdgeSamplingRecord &eRec,
                    Float d_u)
    {
        // Primal output slot required by Enzyme; its value is never read.
        [[maybe_unused]] Float u;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    /**
     * Connect detector-path vertex `its` to the camera and back-propagate that pixel's
     * adjoint into the boundary velocity.
     *
     * Returns early when the camera cannot sample a direction towards its.p, when its.p is
     * occluded from the camera position, or when the sensor value is below Epsilon.
     * Otherwise it forms the scalar adjoint
     *   d_u = sum(d_image[pixel] * weight * bsdf(its -> camera) * sensor_val)
     * and hands it to d_velocity().
     *
     * Forward (debug) mode, when the flag `forward` (declared outside this file) is set:
     * the derivative parameter of the shape returned by getShapeRequiresGrad() is zeroed
     * before d_velocity(), so its value afterwards is this sample's contribution. That value
     * is added to the red channel of the calling thread's debug image.
     * NOTE(review): this discards whatever derivative had accumulated in d_scene for that
     * shape, and assumes getShapeRequiresGrad() never returns -1 in forward mode — confirm.
     *
     * `max_depth` is unused.
     */
    inline void handleSurfaceInteraction(const Intersection &its,
                                         const Scene &scene, Scene &d_scene,
                                         Vector2 &rnd_light, EdgeSamplingRecord &eRec,
                                         RndSampler &sampler, const Spectrum &weight,
                                         int max_depth, std::vector<Spectrum> &d_image)
    {
        int shape_idx = -1;
        if (forward) shape_idx = scene.getShapeRequiresGrad();
        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(its.p, cRec))
            return;
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
            return;
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensor_val < Epsilon)
            return;

        auto bsdf_val = its.evalBSDF(its.toLocal(cRec.dir),
                                     EBSDFMode::EImportanceWithCorrection);

        Spectrum value = weight * bsdf_val;
        // Scalar adjoint of the velocity for this camera connection.
        Float d_u = (d_image[pixel_idx] * value * sensor_val).sum();
        Float param = 0;
        if (forward) {
            // Zero the parameter so the post-call value is this sample's contribution alone.
            d_scene.shape_list[shape_idx]->param = 0.;
            param = d_scene.shape_list[shape_idx]->param;
        }
        d_velocity(scene, d_scene, rnd_light, eRec, d_u);
        if (forward) {
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
        }
    }

    /**
     * Photon-map estimate of the direct-illumination boundary contribution for one primary
     * sample; this is the target function used to fit the guiding distributions.
     *
     * rnd_val[0] selects a point xB on a boundary edge, and rnd_val[1..2] select a point xS on
     * an emitter. The ray from xS through xB must graze the edge and hit a surface point xD
     * that passes the back-face test and is directly visible from the camera. The importance
     * arriving at xD is then estimated from the K nearest photon-map nodes with depth at most
     * max_bounces.
     *
     * @return |contribution| (>= 0), or 0 whenever any step fails. `sampler` is unused.
     */
    Float eval_photon_DirectBoundary(const Vector3 &rnd_val,
                                     const Scene &scene,
                                     RndSampler &sampler,
                                     const DiscreteDistribution &edge_dist,
                                     const std::vector<Vector2i> &edge_indices, int max_bounces,
                                     const std::vector<RadImpNode> &rad_nodes,
                                     const KDtree<Float> &rad_indices)
    {
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(rnd_val[0],
                              edge_dist, edge_indices,
                              eRec);
        // No edge could be sampled: shape_list[-1] below would be out of bounds.
        if (eRec.shape_id == -1)
            return 0.0f;
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample point on emitters */
        DirectSamplingRecord dRec(eRec.ref);
        Vector2 rnd_light(rnd_val[1], rnd_val[2]);
        Spectrum value = scene.sampleEmitterDirect(rnd_light, dRec);
        if (value.isZero(Epsilon))
            return 0.0f;
        const Vector xB = eRec.ref,
                     &xS = dRec.p;
        Ray ray(xB, (xB - xS).normalized());
        Intersection its;
        if (!scene.rayIntersect(ray, true, its))
            return 0.0f;

        // sanity check: the ray must be tangent to the surface at the edge, i.e. the two
        // adjacent faces lie on opposite sides of it
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = ray.dir.dot(n0),
                  dotn1 = ray.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                return 0.0f;
        }

        // NOTE prevent intersection with a backface
        Float gnDotD = its.geoFrame.n.dot(-ray.dir);
        Float snDotD = its.shFrame.n.dot(-ray.dir);
        bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                       (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        if (!success)
            return 0.0f;

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector &xD = its.p;
        const Float J = dlD_dlB(xS,
                                xB, (v0 - v1).normalized(),
                                xD, its.geoFrame.n) *
                        dA_dw(xB, xS, dRec.n);
        const Float baseValue = J * geometric(xD, its.geoFrame.n,
                                              xS, dRec.n);
        if (std::abs(baseValue) < Epsilon)
            return 0.0f;

        // The photon-map estimate is only used where xD is directly seen by the camera.
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
            return 0.0f;

        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;
        scene.camera.sampleDirect(its.p, pixel_uvs, attenuations, dir);
        if (attenuations.maxCoeff() < Epsilon)
            return 0.0f;

        size_t matched_indices[NUM_NEAREST_NEIGHBORS];
        Float pt_rad[3] = {its.p[0], its.p[1], its.p[2]};
        Float matched_r2_rad = 0;
        const int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
        // Empty neighbourhood or coincident photons: the density estimate below would be
        // 0/0 (NaN) or x/0 (inf). The callers filter NaN only, so infinities must not escape.
        if (num_nearby_rad <= 0 || matched_r2_rad <= 0)
            return 0.0f;

        Spectrum photon_radiances(0.0);
        for (int m = 0; m < num_nearby_rad; m++)
        {
            const RadImpNode &node = rad_nodes[matched_indices[m]];
            if (node.depth <= max_bounces)
                photon_radiances += node.val;
        }

        const Float contrib = (value * photon_radiances).maxCoeff() * baseValue / eRec.pdf / matched_r2_rad;
        return std::abs(contrib);
    }

    
    
    void eval_DirectBoundary(const Vector3 &rnd, const Scene &scene,
                                const int &max_bounces,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices,
                                Float &contrib
                                )
    {
        contrib = 0;
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            contrib = 0;
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);
        /* Sample point on emitters */
        DirectSamplingRecord dRec(eRec.ref);
        Vector2 rnd_light(rnd[1], rnd[2]);
        Spectrum value = scene.sampleEmitterDirect(rnd_light, dRec);

        if (value.isZero(Epsilon)){
            contrib = 0;
            return;
        }
        const Vector xB = eRec.ref,
                     xS = dRec.p;
        Ray ray(xB, (xB - xS).normalized());
        Intersection its;
        if (!scene.rayIntersect(ray, true, its)){
            contrib = 0;
            return;
        }

        // sanity check
        // make sure the ray is tangent to the surface
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = ray.dir.dot(n0),
                  dotn1 = ray.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5){
                contrib = 0;
                return;
            }
        }
        // NOTE prevent intersection with a backface
        Float gnDotD = its.geoFrame.n.dot(-ray.dir);
        Float snDotD = its.shFrame.n.dot(-ray.dir);
        bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                       (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        if (!success){
            contrib = 0;
            return;
        }
        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = -ray.dir;
        eRec.shape_id_S = dRec.shape_id;
        eRec.tri_id_S = dRec.tri_id;
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];
        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = detach(shape->getVertex(edge.v0));
        Vector v1 = detach(shape->getVertex(edge.v1));
        Vector v2 = detach(shape->getVertex(edge.v2));
        const Vector &xD = its.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, its.geoFrame.n) *
                  dA_dw(xB, xS, dRec.n);
        Float baseValue = J * geometric(xD, its.geoFrame.n, xS, dRec.n);
            
        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(its.p, cRec)){
            contrib = 0;
            return;
        }

        if (!scene.isVisible(its.p, true, scene.camera.cpos, true)){
            contrib = 0;
            return;
        }
        Float sensor_val = cRec.baseVal;

        if (sensor_val < Epsilon){
            contrib = 0;
            return;
        }
        Spectrum bsdf_val = its.evalBSDF(its.toLocal(cRec.dir),
                                    EBSDFMode::EImportanceWithCorrection);
        Spectrum b_value = value * baseValue / detach(eRec.pdf) * bsdf_val; // eRec.pdf, bsdf_val;
        contrib = b_value.sum();
        return;
    }

    /**
     * Reverse-mode estimate of the direct-illumination boundary term for one primary sample.
     *
     * d_rnd[0] selects a point xB on a boundary edge, and d_rnd[1..2] select a point xS on an
     * emitter. The ray from xS through xB must graze the edge and hit a surface point xD that
     * passes the back-face test. From xD a detector path of at most options.max_bounces
     * vertices is traced; at every vertex handleSurfaceInteraction() connects to the camera and
     * back-propagates the pixel adjoint `d_image` into `d_scene`.
     *
     * @param d_pdf density with which d_rnd was drawn (1 for uniform sampling); the estimate
     *              is divided by it.
     */
    void d_sampleDirectBoundary(const Scene &scene, Scene &d_scene,
                                RndSampler &sampler, const RenderOptions &options,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices,
                                std::vector<Spectrum> &d_image, const Vector3 &d_rnd, Float d_pdf)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(d_rnd[0],
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample point on emitters */
        DirectSamplingRecord dRec(eRec.ref);
        Vector2 rnd_light(d_rnd[1], d_rnd[2]);
        Spectrum value = scene.sampleEmitterDirect(rnd_light, dRec);
        if (value.isZero(Epsilon))
            return;
        const Vector xB = eRec.ref,
                     &xS = dRec.p;
        Ray ray(xB, (xB - xS).normalized());
        Intersection its;
        if (!scene.rayIntersect(ray, true, its))
            return;

        // sanity check: the ray must be tangent to the surface at the edge, i.e. the two
        // adjacent faces lie on opposite sides of it
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = ray.dir.dot(n0),
                  dotn1 = ray.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                return;
        }
        // NOTE prevent intersection with a backface
        Float gnDotD = its.geoFrame.n.dot(-ray.dir);
        Float snDotD = its.shFrame.n.dot(-ray.dir);
        bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                       (!its.ptr_bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
        if (!success)
            return;
        // populate the BoundarySamplingRecord fields that velocity() reads
        eRec.dir = -ray.dir;
        eRec.shape_id_S = dRec.shape_id;
        eRec.tri_id_S = dRec.tri_id;
        eRec.shape_id_D = its.indices[0];
        eRec.tri_id_D = its.indices[1];

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector &xD = its.p;
        const Float J = dlD_dlB(xS,
                                xB, (v0 - v1).normalized(),
                                xD, its.geoFrame.n) *
                        dA_dw(xB, xS, dRec.n);
        const Float baseValue = J * geometric(xD, its.geoFrame.n, xS, dRec.n) / d_pdf;
        if (std::abs(baseValue) < Epsilon)
            return;

        /* Sample detector path */
        for (int i = 0; i < options.max_bounces; i++)
        {
            handleSurfaceInteraction(its, scene, d_scene, rnd_light,
                                     eRec, sampler, value * baseValue / eRec.pdf,
                                     options.max_bounces, d_image);
            Vector wo_local;
            Float bsdf_pdf, bsdf_eta;
            const Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local,
                                                        bsdf_pdf, bsdf_eta,
                                                        EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            const Vector wo = its.toWorld(wo_local);
            const Vector wi = its.toWorld(its.wi);
            const Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            // stop when the shading frame and the geometric normal disagree about the side
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                break;
            value *= bsdf_weight;
            Ray ray_sensor(its.p, wo);
            scene.rayIntersect(ray_sensor, true, its);
            if (!its.isValid())
                break;
        }
    }
} // namespace

// Build the edge sampling tables right away from the given scene.
DirectEdgeIntegrator::DirectEdgeIntegrator(const Scene &scene)
{
    configure(scene);
}

// No properties are read here; the edge tables stay empty until configure() is called.
DirectEdgeIntegrator::DirectEdgeIntegrator([[maybe_unused]] const Properties &props)
{
}

void DirectEdgeIntegrator::configure(const Scene &scene)
{
    edge_dist.clear();
    edge_indices.clear();
    edge_indices_inv.clear();
    edge_indices_inv.resize(scene.shape_list.size());
    /* generate the edge distribution */
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &indices_inv = edge_indices_inv[i];
        auto &shape = *scene.shape_list[i];
        indices_inv.resize(shape.edges.size());
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            indices_inv[j] = -1;
        }
        if (!shape.enable_edge)
        {
            continue;
        }
        const BSDF &bsdf = *scene.bsdf_list[shape.bsdf_id];
        if (bsdf.isNull() &&
            shape.light_id < 0)
            continue;

        Float total_dist = 0.0;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];

            if (edge.mode == 0 ||
                bsdf.isTransmissive())
            {
                indices_inv[j] = edge_indices.size();
                edge_indices.push_back({i, j});
                edge_dist.append(edge.length * 4 * M_PI); // area of a unit sphere
                if (!shape.enable_draw)
                {
                    draw_dist.append(edge.length * 4 * M_PI);
                    total_dist = 0.0;
                }
                total_dist += edge.length * 4 * M_PI;
            }
            if (edge.mode > 0)
            {
                indices_inv[j] = edge_indices.size();
                const Vector &n0 = shape.getGeoNormal(edge.f0),
                             &n1 = shape.getGeoNormal(edge.f1);
                Float cos_angle = n0.dot(n1);
                edge_indices.push_back({i, j});
                if (!shape.enable_draw)
                {
                    draw_dist.append(edge.length * 4 * std::acos(cos_angle));
                    total_dist = 0.0;
                }
                edge_dist.append(edge.length * 4 * std::acos(cos_angle));
                total_dist += edge.length * 4 * std::acos(cos_angle);
            }
            if (shape.enable_draw)
            {
                if (j < shape.edges.size() - 1)
                {
                    const Edge &edge_next = shape.edges[j + 1];
                    if (edge.v1 != edge_next.v0)
                    {
                        draw_dist.append(total_dist);
                        total_dist = 0.0;
                    }
                }
            }
        }
        if (shape.enable_draw)
        {
            draw_dist.append(total_dist);
            total_dist = 0.0;
        }
    }
    edge_dist.normalize();
    draw_dist.normalize();
    // std::cout << "edge dist size: " << edge_dist.m_cdf.size() - 1 << std::endl;
    // std::cout << "edge draw size: " << draw_dist.m_cdf.size() - 1 << std::endl;
}

/**
 * Rebuild the edge sampling tables for a (possibly modified) scene and log the size of the
 * draw distribution.
 *
 * The table construction is identical to configure(), so this delegates to it instead of
 * duplicating that logic. draw_dist is cleared first so the rebuild never appends to a
 * stale draw distribution. The call is qualified so it always runs this class's
 * configure(), matching the previous inline implementation.
 */
void DirectEdgeIntegrator::recompute_edge(const Scene &scene)
{
    draw_dist.clear();
    DirectEdgeIntegrator::configure(scene);
    std::cout << "edge draw size: " << draw_dist.m_cdf.size() - 1 << std::endl;
}

/**
 * Fit the grid guiding distribution for direct boundary sampling.
 *
 * Builds a photon map from camera paths (importance = false), indexes the node positions in
 * a KD-tree, and uses eval_photon_DirectBoundary() as the target function for
 * grid_distrb.setup().
 */
void DirectEdgeIntegrator::preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
{
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::cout << "[INFO] Direct Guiding: #camPath = " << opts.num_cam_path << std::endl;
    std::vector<RadImpNode> rad_nodes;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces, rad_nodes, false);
    std::cout << "[INFO] Direct Guiding: #rad_nodes = " << rad_nodes.size() << std::endl;

    // rad_cloud must outlive rad_indices (nanoflann adaptors keep a reference to the dataset).
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    for (size_t i = 0; i < rad_nodes.size(); i++)
    {
        rad_cloud.pts[i].x = rad_nodes[i].p[0];
        rad_cloud.pts[i].y = rad_nodes[i].p[1];
        rad_cloud.pts[i].z = rad_nodes[i].p[2];
    }
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();
    std::cout << "preprocessing DirectEdgeIntegrator" << std::endl;
    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler) -> Float
    {
        const Float result = eval_photon_DirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, rad_indices);
        // Reject infinities as well as NaN: a single non-finite value would presumably
        // corrupt the normalisation of the whole guiding distribution.
        return std::isfinite(result) ? result : Float(0);
    };
    grid_distrb.setup(NEE_function, config);
    std::cout << "finish grid guiding" << std::endl;
}

/**
 * Fit the adaptive (AQ) guiding distribution for direct boundary sampling.
 *
 * Builds a photon map from camera paths (importance = false), indexes the node positions in
 * a KD-tree, and uses eval_photon_DirectBoundary() as the target function for
 * aq_distrb.setup(). The initial 1D partition along the edge dimension is the CDF of
 * draw_dist when config.edge_draw is set, and of edge_dist otherwise.
 */
void DirectEdgeIntegrator::preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
{
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::cout << "[INFO] Direct Guiding: #camPath = " << opts.num_cam_path << std::endl;
    std::vector<RadImpNode> rad_nodes;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces, rad_nodes, false);
    std::cout << "[INFO] Direct Guiding: #rad_nodes = " << rad_nodes.size() << std::endl;

    // rad_cloud must outlive rad_indices (nanoflann adaptors keep a reference to the dataset).
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    for (size_t i = 0; i < rad_nodes.size(); i++)
    {
        rad_cloud.pts[i].x = rad_nodes[i].p[0];
        rad_cloud.pts[i].y = rad_nodes[i].p[1];
        rad_cloud.pts[i].z = rad_nodes[i].p[2];
    }
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();

    std::cout << "AQ preprocessing DirectEdgeIntegrator" << std::endl;
    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler) -> Float
    {
        const Float result = eval_photon_DirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, rad_indices);
        // Reject infinities as well as NaN: a single non-finite value would presumably
        // corrupt the normalisation of the whole guiding distribution.
        return std::isfinite(result) ? result : Float(0);
    };
    std::vector<Float> cdfx;
    if (config.edge_draw)
    {
        std::cout << "AQ using draw edge" << std::endl;
        cdfx = draw_dist.m_cdf;
    }
    else
    {
        std::cout << "AQ using individual edge" << std::endl;
        cdfx = edge_dist.m_cdf;
    }

    std::cout << "Inital curvature size: " << cdfx.size() << std::endl;
    aq_distrb.setup(NEE_function, cdfx, config);
    std::cout << "finish AQ guiding" << std::endl;
}

/**
 * Reverse-mode rendering of the direct-illumination boundary term.
 *
 * Draws num_samples_secondary_edge_direct samples per pixel (grouped in blocks of
 * block_size^2 pixels), optionally warped by the grid or AQ guiding distribution. Each sample
 * back-propagates the image adjoint `__d_image` into the per-thread gradients, which are then
 * merged into sceneAD.der.
 *
 * @return the flattened debug image (all zeros unless forward mode wrote to it).
 */
ArrayXd DirectEdgeIntegrator::renderD(
    SceneAD &sceneAD, RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("DirectEdgeIntegrator::renderD with spp = {}",
              options.num_samples_secondary_edge_direct);
    const Scene &scene = sceneAD.val;
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_secondary_edge_direct;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    // Same (flattened) layout as the regular exit at the end of the function.
    if (nsamples <= 0)
        return flattened(debugInfo.getArray());

    // Pre-divide the adjoint by the total sample count so each sample adds its share.
    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());

    Timer _("Direct boundary");
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        // A block is processed entirely by one thread.
        const int tid = omp_get_thread_num();
        for (int j = 0; j < nblock_samples; j++)
        {
            // Seeded by the global sample index, independent of which thread runs it.
            RndSampler sampler(options.seed, i * nblock_samples + j);

            Vector3 d_rnd = sampler.next3D();
            Float d_pdf = 1.0;
            // Warp through a guiding distribution if one was fitted (grid takes precedence).
            if (grid_distrb.distrb.getSum() > Epsilon)
                d_rnd = grid_distrb.sample(d_rnd, d_pdf);
            else if (aq_distrb.distrb.getSum() > Epsilon)
                d_rnd = aq_distrb.sample(d_rnd, d_pdf);
            assert(d_pdf > 0.0);
            // sample a point on the boundary
            d_sampleDirectBoundary(scene, gm.get(tid),
                                   sampler, options,
                                   edge_dist, edge_indices,
                                   _d_image_spec_list, d_rnd, d_pdf);
        }
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
        }
    }
    if (verbose)
        std::cout << std::endl;

    // merge gradient to d_scene
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    d_precompute_normal(scene, d_scene);
#endif

    return flattened(debugInfo.getArray());
}

/**
 * Forward-mode (debug) variant of renderD().
 *
 * Uses a constant unit-red adjoint for every pixel (divided by the total sample count), so
 * the per-sample derivative contributions recorded by handleSurfaceInteraction() in forward
 * mode end up in the debug image.
 *
 * @return the flattened debug image.
 */
ArrayXd DirectEdgeIntegrator::forwardRenderD(
    SceneAD &sceneAD, RenderOptions &options) const
{
    PSDR_INFO("DirectEdgeIntegrator::forwardRenderD with spp = {}",
              options.num_samples_secondary_edge_direct);
    const Scene &scene = sceneAD.val;
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_secondary_edge_direct;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    // Same (flattened) layout as the regular exit at the end of the function.
    if (nsamples <= 0)
        return flattened(debugInfo.getArray());

    std::vector<Spectrum> _d_image_spec_list(camera.getNumPixels(), Spectrum(1, 0, 0) / nblock_samples / nblocks);

    Timer _("Direct boundary");
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        // A block is processed entirely by one thread.
        const int tid = omp_get_thread_num();
        for (int j = 0; j < nblock_samples; j++)
        {
            // Seeded by the global sample index, independent of which thread runs it.
            RndSampler sampler(options.seed, i * nblock_samples + j);

            Vector3 d_rnd = sampler.next3D();
            Float d_pdf = 1.0;
            // Warp through a guiding distribution if one was fitted (grid takes precedence).
            if (grid_distrb.distrb.getSum() > Epsilon)
                d_rnd = grid_distrb.sample(d_rnd, d_pdf);
            else if (aq_distrb.distrb.getSum() > Epsilon)
                d_rnd = aq_distrb.sample(d_rnd, d_pdf);
            assert(d_pdf > 0.0);
            // sample a point on the boundary
            d_sampleDirectBoundary(scene, gm.get(tid),
                                   sampler, options,
                                   edge_dist, edge_indices,
                                   _d_image_spec_list, d_rnd, d_pdf);
        }
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
        }
    }
    if (verbose)
        std::cout << std::endl;

    // merge gradient to d_scene
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    d_precompute_normal(scene, d_scene);
#endif

    return flattened(debugInfo.getArray());
}