#include "boundary.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <render/photon_map.h>
#include <core/logger.h>
#include <core/properties.h>

#include <algorithm>
#include <cmath>
#include <limits>

namespace
{
    // Scratch state handed to __Li for a single radiance query.
    struct BoundaryRadianceQueryRecord
    {
        RndSampler *sampler;          // random source used by __Li (not freed by this struct)
        int max_bounces;              // bounce limit for __Li's loop
        std::vector<Spectrum> values; // per-bounce contributions, max_bounces + 1 entries, zero-initialized

        BoundaryRadianceQueryRecord(RndSampler *rng, int bounce_limit)
            : sampler(rng),
              max_bounces(bounce_limit),
              values(bounce_limit + 1, Spectrum::Zero())
        {
        }
    };
    // Debug accumulation target: in forward mode handleSurfaceInteraction adds
    // per-pixel values into debugInfo.image_per_thread. Internal linkage already
    // comes from the enclosing anonymous namespace.
    DebugInfo debugInfo;

    void buildPhotonMap(const Scene &scene, int num_paths, int max_bounces, std::vector<RadImpNode> &nodes, bool importance)
    {
        const int nworker = omp_get_num_procs();
        std::vector<RndSampler> samplers;
        for (int i = 0; i < nworker; ++i)
            samplers.push_back(RndSampler(17, i));

        std::vector<std::vector<RadImpNode>> nodes_per_thread(nworker);
        for (int i = 0; i < nworker; i++)
        {
            nodes_per_thread[i].reserve(num_paths / nworker * max_bounces);
        }
        const Camera &camera = scene.camera;
        const CropRectangle &rect = camera.rect;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (size_t omp_i = 0; omp_i < (size_t)num_paths; omp_i++)
        {
            const int tid = omp_get_thread_num();
            RndSampler &sampler = samplers[tid];
            Ray ray;
            Spectrum throughput;
            Intersection its;
            const Medium *ptr_med = nullptr;
            int depth = 0;
            bool onSurface = false;
            if (!importance)
            {
                // Trace ray from camera
                Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width : sampler.next1D() * camera.width;
                Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height : sampler.next1D() * camera.height;
                ray = camera.samplePrimaryRay(x, y);
                throughput = Spectrum(1.0);
                ptr_med = camera.getMedID() == -1 ? nullptr : scene.medium_list[camera.getMedID()];
            }
            else
            {
                // Trace ray from emitter
                throughput = scene.sampleEmitterPosition(sampler.next2D(), its);
                ptr_med = its.ptr_med_ext;
                nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                ray.org = its.p;
                throughput *= its.ptr_emitter->sampleDirection(sampler.next2D(), ray.dir);
                ray.dir = its.geoFrame.toWorld(ray.dir);
                depth++;
                onSurface = true;
            }
            if (scene.rayIntersect(ray, onSurface, its))
            {
                while (depth <= max_bounces)
                {
                    bool inside_med = ptr_med != nullptr &&
                                      ptr_med->sampleDistance(Ray(ray), its.t, sampler.next2D(), &sampler, ray.org, throughput);
                    if (inside_med)
                    {
                        if (throughput.isZero())
                            break;
                        const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
                        nodes_per_thread[tid].push_back(RadImpNode{ray.org, throughput, depth});
                        Float phase_val = ptr_phase->sample(-ray.dir, sampler.next2D(), ray.dir);
                        if (phase_val == 0.0)
                            break;
                        throughput *= phase_val;
                        scene.rayIntersect(ray, false, its);
                    }
                    else
                    {
                        nodes_per_thread[tid].push_back(RadImpNode{its.p, throughput, depth});
                        Float bsdf_pdf, bsdf_eta;
                        Vector wo_local, wo;
                        Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                                                              importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                        if (bsdf_weight.isZero())
                            break;
                        throughput = throughput * bsdf_weight;

                        wo = its.toWorld(wo_local);
                        Vector wi = -ray.dir;
                        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                            break;

                        if (its.isMediumTransition())
                            ptr_med = its.getTargetMedium(woDotGeoN);

                        ray = Ray(its.p, wo);
                        if (!scene.rayIntersect(ray, true, its))
                            break;
                    }
                    depth++;
                }
            }
        }
        size_t sz_node = 0;
        for (int i = 0; i < nworker; i++)
            sz_node += nodes_per_thread[i].size();
        nodes.reserve(sz_node);
        for (int i = 0; i < nworker; i++)
            nodes.insert(nodes.end(), nodes_per_thread[i].begin(), nodes_per_thread[i].end());
    }

    /*
     * k-nearest-neighbour lookup of NUM_NEAREST_NEIGHBORS photons around query_point.
     *
     * matched_indices  : output, must hold NUM_NEAREST_NEIGHBORS entries.
     * matched_dist_sqr : output, squared distance to the farthest match (the
     *                    density-estimation radius); +infinity when nothing matched.
     * Returns the number of matches found.
     */
    int queryPhotonMap(const KDtree<Float> &indices, const Float *query_point, size_t *matched_indices, Float &matched_dist_sqr)
    {
        Float dist_sqr[NUM_NEAREST_NEIGHBORS];
        const int num_matched = static_cast<int>(
            indices.knnSearch(query_point, NUM_NEAREST_NEIGHBORS, matched_indices, dist_sqr));
        assert(num_matched == NUM_NEAREST_NEIGHBORS);
        if (num_matched <= 0)
        {
            // Empty photon map (possible with NDEBUG): avoid reading dist_sqr[-1].
            // An infinite radius makes callers' value / r^2 estimates collapse to zero.
            matched_dist_sqr = std::numeric_limits<Float>::infinity();
            return 0;
        }
        matched_dist_sqr = dist_sqr[num_matched - 1];
        return num_matched;
    }

    /*
     * Normal velocity of the boundary point described by bRec, written to `res`.
     *
     * Gathers the three vertices of the sampled edge's shape entry (xB_*), of the
     * triangle hit on the S side (bRec.shape_id_S / tri_id_S) and of the triangle
     * hit on the D side (bRec.shape_id_D / tri_id_D), then evaluates
     * normal_velocity() at edge parameter bRec.t along direction bRec.dir.
     *
     * Only referenced as an Enzyme autodiff target from d_velocity (when ENZYME and
     * ENZYME_BOUNDARY_INDIRECT are defined), hence [[maybe_unused]]; the result is
     * returned through `res` to pair with the enzyme_dupnoneed shadow there.
     * Keep this body simple: Enzyme differentiates it as written.
     */
    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        // Boundary (edge) geometry
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        // Triangle hit on the S side of the edge ray
        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        // Triangle hit on the D side of the edge ray
        const Shape *shapeD = scene.shape_list[bRec.shape_id_D];
        const auto &indD = shapeD->getIndices(bRec.tri_id_D);
        const Vector &xD_0 = shapeD->getVertex(indD[0]);
        const Vector &xD_1 = shapeD->getVertex(indD[1]);
        const Vector &xD_2 = shapeD->getVertex(indD[2]);

        res = normal_velocity(xS_0, xS_1, xS_2,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xD_0, xD_1, xD_2);
    }

    /*
     * Back-propagates the adjoint d_u of velocity()'s output into d_scene using
     * Enzyme reverse-mode autodiff. Compiles to a no-op unless both ENZYME and
     * ENZYME_BOUNDARY_INDIRECT are defined.
     *
     * NOTE(review): eRec is typed EdgeRaySamplingRecord here, but velocity() reads
     * it as a BoundarySamplingRecord (shape_id_S, tri_id_D, ...). The (void *) cast
     * hides the mismatch from the compiler. The visible caller chain
     * (d_sampleDirectBoundary -> handleSurfaceInteraction) passes a
     * BoundarySamplingRecord object, so this relies on that object's address being
     * usable as either type. Confirm before adding other callers.
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    const EdgeRaySamplingRecord &eRec,
                    Float d_u)
    {
        // Primal output slot; Enzyme does not need its value (enzyme_dupnoneed).
        [[maybe_unused]] Float u;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_INDIRECT)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    /*
     * Connects the detector-side vertex `its` to the camera. If it is visible and
     * lands on a pixel, the pixel adjoint is back-propagated into d_scene through
     * d_velocity.
     *
     * weight    : throughput/radiance weight carried by this vertex.
     * max_depth : currently unused.
     * d_image   : per-pixel adjoint, indexed by Camera::sampleDirectPixel's pixel.
     *
     * When `forward` is set and scene.getShapeRequiresGrad() yields a valid index,
     * the change this call makes to that shape's `param` in d_scene is added to
     * debugInfo (red channel) for the pixel.
     */
    inline void handleSurfaceInteraction(const Intersection &its,
                                         const Scene &scene, Scene &d_scene,
                                         EdgeRaySamplingRecord &eRec,
                                         RndSampler &sampler, const Spectrum &weight,
                                         [[maybe_unused]] int max_depth, std::vector<Spectrum> &d_image)
    {
        int shape_idx = -1;
        if (forward)
            shape_idx = scene.getShapeRequiresGrad();
        // A negative index would read outside d_scene.shape_list in the debug path.
        const bool track_param = forward && shape_idx >= 0;

        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(its.p, cRec))
            return;
        if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
            return;
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensor_val < Epsilon)
            return;

        auto bsdf_val = its.evalBSDF(its.toLocal(cRec.dir),
                                     EBSDFMode::EImportanceWithCorrection);
        Spectrum value = weight * bsdf_val;
        Float d_u = (d_image[pixel_idx] * value * sensor_val).sum();

        Float param = 0;
        if (track_param)
        {
            // Zero the accumulator so the difference below isolates this call's contribution.
            d_scene.shape_list[shape_idx]->param = 0.;
            param = d_scene.shape_list[shape_idx]->param;
        }
        d_velocity(scene, d_scene, eRec, d_u);
        if (track_param)
        {
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
        }
    }

    void sampleEdgeRay(const Scene &scene, const Array3 rnd,
                       const DiscreteDistribution &edge_dist,
                       const std::vector<Vector2i> &edge_indices,
                       EdgeRaySamplingRecord &eRec)
    {
        /* Sample a point on the boundary */
        scene.sampleEdgePoint(rnd[0],
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample edge ray */
        if (edge.f1 < 0) // Case 1: boundary edge
        {
            eRec.dir = squareToUniformSphere(Vector2{rnd[1], rnd[2]});
            eRec.pdf /= 4. * M_PI;
        }
        else // Case 2: non-boundary edge
        {
            Vector n0 = shape->getGeoNormal(edge.f0);
            Vector n1 = shape->getGeoNormal(edge.f1);
            Float pdf1;
            eRec.dir = squareToEdgeRayDirection(Vector2(rnd[1], rnd[2]), n0, n1, pdf1);
            eRec.pdf *= pdf1;

            Float dotn0 = eRec.dir.dot(n0), dotn1 = eRec.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5f)
            {
                std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
                eRec.shape_id = -1;
            }
        }
    }

    /*
     * Path-traces the radiance arriving along _ray using next-event estimation.
     * When MIS is defined, it also adds power-heuristic-weighted emitter hits
     * found by BSDF sampling.
     *
     * Besides returning the total, contributions are binned by bounce in
     * rRec.values:
     *   values[0]         emission seen directly by _ray;
     *   values[depth + 1] light gathered in loop iteration `depth`, for
     *                     depth < rRec.max_bounces.
     * rRec.values must therefore have at least max_bounces + 1 entries, which
     * BoundaryRadianceQueryRecord's constructor provides.
     *
     * NOTE(review): identifiers beginning with "__" are reserved in C++; kept for
     * compatibility with existing callers.
     */
    Spectrum __Li(const Scene &scene, const Ray &_ray, BoundaryRadianceQueryRecord &rRec)
    {
        Ray ray(_ray);
        RndSampler *sampler = rRec.sampler;
        Intersection its;

        Spectrum ret = Spectrum::Zero();
        scene.rayIntersect(ray, true, its);
        if (!its.isValid())
            return Spectrum::Zero();

        Spectrum throughput = Spectrum::Ones();
        if (its.isEmitter())
        {
            const Spectrum emitted = throughput * its.Le(-ray.dir);
            ret += emitted;
            rRec.values[0] += emitted;
        }
        for (int depth = 0; depth < rRec.max_bounces && its.isValid(); depth++)
        {
            // Direct illumination (next-event estimation)
            DirectSamplingRecord dRec(its);
            auto value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
            Vector wo = its.toLocal(dRec.dir);
            if (!value.isZero(Epsilon))
            {
                auto bsdf_val = its.evalBSDF(wo);
#if defined(MIS)
                Float bsdf_pdf = its.pdfBSDF(wo);
                // Convert the emitter's area-measure pdf to solid angle for the heuristic.
                Float pdf_nee = dRec.pdf / geometric(its.p, dRec.p, dRec.n);
                auto mis_weight = square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
                const Spectrum contrib = throughput * value * bsdf_val * mis_weight;
#else
                const Spectrum contrib = throughput * value * bsdf_val;
#endif
                ret += contrib;
                rRec.values[depth + 1] += contrib;
            }

            // Indirect illumination: continue the path by BSDF sampling (wo is reused as output).
            Float bsdf_pdf, bsdf_eta;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
            if (bsdf_weight.isZero(Epsilon))
                break;
            wo = its.toWorld(wo);
            ray = Ray(its.p, wo);

            if (!scene.rayIntersect(ray, true, its))
                break;

            throughput *= bsdf_weight;

#if defined(MIS)
            if (its.isEmitter())
            {
                Spectrum light_contrib = its.Le(-ray.dir);
                if (!light_contrib.isZero(Epsilon))
                {
                    auto dist_sq = (its.p - ray.org).squaredNorm();
                    auto geometry_term = its.wi.z() / dist_sq;
                    Float pdf_nee = scene.pdfEmitterSample(its) / geometry_term;
                    auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                    const Spectrum contrib = throughput * light_contrib * mis_weight;
                    ret += contrib;
                    rRec.values[depth + 1] += contrib;
                }
            }
#endif
        }
        return ret;
    }

    /*
     * Target function for the indirect-boundary guiding grid. It maps a primary
     * sample `rnd_val` to a non-negative scalar approximating the contribution of
     * the corresponding boundary sample.
     *
     * The boundary term (Jacobian * geometric term, scaled by the largest
     * normal-velocity derivative when `local_backward` is set) is multiplied by
     * k-NN photon density estimates:
     *   - rad_* : path vertices looked up around the D-side hit. preprocess_grid
     *             builds them from camera-traced paths (importance == false).
     *   - imp_* : path vertices looked up around the S-side hit. preprocess_grid
     *             builds them from emitter-traced paths.
     *
     * shape_opt_id : when != -1, only edges of that shape contribute.
     * sampler      : currently unused.
     * Returns 0 for invalid or rejected samples.
     */
    Float eval_photon_InDirectBoundary(const Vector3 &rnd_val,
                                       const Scene &scene,
                                       RndSampler &sampler,
                                       const DiscreteDistribution &edge_dist,
                                       const std::vector<Vector2i> &edge_indices, int max_bounces,
                                       const std::vector<RadImpNode> &rad_nodes, const std::vector<RadImpNode> &imp_nodes,
                                       const KDtree<Float> &rad_indices, const KDtree<Float> &imp_indices, int shape_opt_id, bool local_backward)
    {
        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, rnd_val, edge_dist, edge_indices, eRec);
        if (eRec.shape_id < 0)
            return 0.0f;
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];

        // S: first hit along eRec.dir; D: first hit along the opposite direction.
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, itsD;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, itsD))
            return 0.0;
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = itsD.indices[0];
        eRec.tri_id_D = itsD.indices[1];

        // make sure the ray is tangent to the surface (it must separate the two faces)
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = edgeRay.dir.dot(n0),
                  dotn1 = edgeRay.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
            {
                PSDR_ASSERT_MSG(false, "Bad edge ray sample: [{}, {}]", dotn0, dotn1);
                return 0.0;
            }
        }

        /* prevent self intersection */
        const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
        if (itsS.indices == ind0 || itsS.indices == ind1 ||
            itsD.indices == ind0 || itsD.indices == ind1)
            return 0.0;

        // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
        const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                    gn2d1 = itsD.geoFrame.n.dot(edgeRay.dir), sn2d1 = itsD.shFrame.n.dot(edgeRay.dir);
        bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
             valid2 = (itsD.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!itsD.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
        if (!valid1 || !valid2)
            return 0.0;

        if (eRec.shape_id != shape_opt_id && shape_opt_id != -1)
            return 0.0;

        Float max_normal = 1.0;

        if (local_backward)
        {
            const Vector &xB_0 = shape->getVertex(edge.v0);
            const Vector &xB_1 = shape->getVertex(edge.v1);
            const Vector &xB_2 = shape->getVertex(edge.v2);

            const Shape *shapeS = scene.shape_list[eRec.shape_id_S];
            const auto &indS = shapeS->getIndices(eRec.tri_id_S);
            const Vector &xS_0 = shapeS->getVertex(indS[0]);
            const Vector &xS_1 = shapeS->getVertex(indS[1]);
            const Vector &xS_2 = shapeS->getVertex(indS[2]);

            const Shape *shapeD = scene.shape_list[eRec.shape_id_D];
            const auto &indD = shapeD->getIndices(eRec.tri_id_D);
            const Vector &xD_0 = shapeD->getVertex(indD[0]);
            const Vector &xD_1 = shapeD->getVertex(indD[1]);
            const Vector &xD_2 = shapeD->getVertex(indD[2]);

            // Note the deliberate swap: the segment's S slot holds shape D's triangle
            // and its D slot holds shape S's triangle.
            BoundarySegmentInfo segInfo;
            segInfo.xS_0 = xD_0;
            segInfo.xS_1 = xD_1;
            segInfo.xS_2 = xD_2;

            segInfo.xB_0 = xB_0;
            segInfo.xB_1 = xB_1;

            segInfo.xD_0 = xS_0;
            segInfo.xD_1 = xS_1;
            segInfo.xD_2 = xS_2;

            BoundarySegmentInfo d_segInfo;
            d_segInfo.setZero();

            d_normal_velocity(segInfo, d_segInfo, xB_2, eRec.t, eRec.dir, 1.0);

            if (shape_opt_id == -1)
            {
                max_normal = d_segInfo.maxCoeff();
            }
            else
            {
                // Each flag tests the shape whose vertices occupy that slot. Because of
                // the swap above, the S slot holds shape D and the D slot holds shape S.
                // Shape ids are compared here, never triangle ids. Assumes maxCoeff's
                // flags select the S, B, D slots in that order; confirm against
                // BoundarySegmentInfo.
                max_normal = d_segInfo.maxCoeff(eRec.shape_id_D == shape_opt_id,
                                                eRec.shape_id == shape_opt_id,
                                                eRec.shape_id_S == shape_opt_id);
            }

            if (max_normal < Epsilon)
                return 0.0;
        }

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector xB = v0 + (v1 - v0) * eRec.t;
        const Vector &xS = itsS.p;
        const Vector &xD = itsD.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, itsD.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n) * max_normal;

        // The D-side point must see the camera and project onto the sensor.
        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;
        if (!scene.isVisible(itsD.p, true, scene.camera.cpos, true))
            return 0.0;
        scene.camera.sampleDirect(itsD.p, pixel_uvs, attenuations, dir);

        if (eRec.pdf < Epsilon)
            return 0.0;
        if (attenuations.maxCoeff() < Epsilon)
            return 0.0;

        // k-NN density estimates; matched_indices is reused for both queries.
        size_t matched_indices[NUM_NEAREST_NEIGHBORS];
        Float pt_rad[3] = {itsD.p[0], itsD.p[1], itsD.p[2]};
        Float pt_imp[3] = {itsS.p[0], itsS.p[1], itsS.p[2]};
        Float matched_r2_rad, matched_r2_imp;

        // photon_radiances[d]: summed node values at depth d around xD.
        const int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
        std::vector<Spectrum> photon_radiances(max_bounces + 1, Spectrum::Zero());
        for (int m = 0; m < num_nearby_rad; m++)
        {
            const RadImpNode &node = rad_nodes[matched_indices[m]];
            if (node.depth <= max_bounces)
                photon_radiances[node.depth] += node.val;
        }

        // importance[d]: summed node values at depth d around xS.
        const int num_nearby_imp = queryPhotonMap(imp_indices, pt_imp, matched_indices, matched_r2_imp);
        std::vector<Spectrum> importance(max_bounces, Spectrum::Zero());
        for (int m = 0; m < num_nearby_imp; m++)
        {
            const RadImpNode &node = imp_nodes[matched_indices[m]];
            if (node.depth < max_bounces)
                importance[node.depth] += node.val;
        }

        // Combine depth pairs (m, n) with n >= impStart and m + n < max_bounces.
        Spectrum value2 = Spectrum::Zero();
        const int impStart = 1;
        for (int m = 0; m <= max_bounces; m++)
        {
            for (int n = impStart; n < max_bounces - m; n++)
                value2 += photon_radiances[m] * importance[n];
        }

        // NOTE(review): d_sampleDirectBoundary divides by eRec.pdf, whereas this
        // target multiplies by it. This affects only guiding efficiency; confirm intent.
        return std::abs(baseValue * eRec.pdf) * std::abs(value2.maxCoeff() / (matched_r2_rad * matched_r2_imp));
    }

    /*
     * Differentiates one indirect-boundary sample.
     *
     * Maps d_rnd to an edge ray and finds its S and D hits. It then estimates
     * per-bounce radiance arriving at the boundary from S with __Li. Finally it
     * walks a detector path from D, calling handleSurfaceInteraction at each
     * vertex to back-propagate d_image into d_scene.
     *
     * d_pdf : density with which d_rnd was drawn; the boundary term is divided by
     *         it and by eRec.pdf.
     */
    void d_sampleDirectBoundary(const Scene &scene, Scene &d_scene,
                                RndSampler &sampler, const RenderOptions &options,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices,
                                std::vector<Spectrum> &d_image, const Vector3 &d_rnd, Float d_pdf)
    {
        BoundarySamplingRecord eRec;
        sampleEdgeRay(scene, d_rnd, edge_dist, edge_indices, eRec);
        if (eRec.shape_id < 0)
            return;
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        Ray edgeRay(eRec.ref, eRec.dir);
        Intersection itsS, itsD;
        if (!scene.rayIntersect(edgeRay, true, itsS) ||
            !scene.rayIntersect(edgeRay.flipped(), true, itsD))
            return;
        // populate the data in BoundarySamplingRecord eRec
        eRec.shape_id_S = itsS.indices[0];
        eRec.tri_id_S = itsS.indices[1];
        eRec.shape_id_D = itsD.indices[0];
        eRec.tri_id_D = itsD.indices[1];

        // make sure the ray is tangent to the surface
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = edgeRay.dir.dot(n0),
                  dotn1 = edgeRay.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
            {
                PSDR_ASSERT_MSG(false, "Bad edge ray sample: [ {}, {} ]", dotn0, dotn1);
                return;
            }
        }

        /* prevent self intersection */
        const Vector2i ind0(eRec.shape_id, edge.f0), ind1(eRec.shape_id, edge.f1);
        if (itsS.indices == ind0 || itsS.indices == ind1 ||
            itsD.indices == ind0 || itsD.indices == ind1)
            return;
        // FIXME: if the isTransmissive() of a dielectric returns false, the rendering time will be short, and the variance will be small.
        const Float gn1d1 = itsS.geoFrame.n.dot(-edgeRay.dir), sn1d1 = itsS.shFrame.n.dot(-edgeRay.dir),
                    gn2d1 = itsD.geoFrame.n.dot(edgeRay.dir), sn2d1 = itsD.shFrame.n.dot(edgeRay.dir);
        bool valid1 = (itsS.ptr_bsdf->isTransmissive() && math::signum(gn1d1) * math::signum(sn1d1) > 0.5f) || (!itsS.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
             valid2 = (itsD.ptr_bsdf->isTransmissive() && math::signum(gn2d1) * math::signum(sn2d1) > 0.5f) || (!itsD.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
        if (!valid1 || !valid2)
            return;

        /* Jacobian determinant that accounts for the change of variable */
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        const Vector xB = v0 + (v1 - v0) * eRec.t;
        const Vector &xS = itsS.p;
        const Vector &xD = itsD.p;
        Float J = dlD_dlB(xS,
                          xB, (v0 - v1).normalized(),
                          xD, itsD.geoFrame.n) *
                  dA_dw(xB, xS, itsS.geoFrame.n);
        Float baseValue = J * geometric(xD, itsD.geoFrame.n, xS, itsS.geoFrame.n) / d_pdf;

        /* Sample source path */
        BoundaryRadianceQueryRecord rRec(&sampler, options.max_bounces);
        __Li(scene, Ray{xB, (xS - xB).normalized()}.shifted(), rRec);
        std::vector<Spectrum> radiances(std::move(rRec.values));
        // Prefix sums: radiances[k] = contributions of bounces 1..k (bounce 0 excluded).
        radiances[0] = Spectrum(0.0f);
        for (size_t i = 1; i < radiances.size(); i++)
            radiances[i] += radiances[i - 1];

        /* Sample detector path */
        Spectrum throughput(1.0f);
        Intersection its(itsD);
        for (int i = 0; i < options.max_bounces; i++)
        {
            // Detector vertex i pairs with source paths of at most max_bounces - 1 - i bounces.
            handleSurfaceInteraction(its, scene, d_scene,
                                     eRec, sampler,
                                     radiances[options.max_bounces - 1 - i] * throughput * baseValue / eRec.pdf,
                                     options.max_bounces, d_image);
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local,
                                                  bsdf_pdf, bsdf_eta,
                                                  EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero())
                break;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            // Reject paths whose geometric and shading normals disagree (light leaks).
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                break;
            throughput *= bsdf_weight;
            Ray ray_sensor(its.p, wo);
            scene.rayIntersect(ray_sensor, true, its);
            if (!its.isValid())
                break;
        }
    }
} // namespace

// Builds the edge sampling distributions for `scene` immediately.
IndirectEdgeIntegrator::IndirectEdgeIntegrator(const Scene &scene)
{
    this->configure(scene);
}

// Property-based construction: nothing is read from `props` and no edge
// distribution is built here; configure()/recompute_edge() populate them.
IndirectEdgeIntegrator::IndirectEdgeIntegrator(const Properties &props)
{
    static_cast<void>(props);
}

/*
 * (Re)builds the edge sampling data from `scene`:
 *   edge_indices / edge_dist : one entry per (shape, edge, sampling mode), weighted
 *     by length * 4*pi for full-sphere sampling (mode 0 or transmissive BSDF),
 *     or by length * 4*acos(n0 . n1) for wedge sampling (mode > 0).
 *   draw_dist : for shapes with enable_draw, one entry per run of consecutive
 *     edges connected end-to-start (v1 == next.v0); otherwise one entry per
 *     edge_dist entry.
 * Shapes with edges disabled, or with a null BSDF that are not emitters, are skipped.
 * All three containers are cleared first, so calling this repeatedly is safe.
 */
void IndirectEdgeIntegrator::configure(const Scene &scene)
{
    PSDR_INFO("Configuring edge_dist_inv...");
    edge_dist.clear();
    draw_dist.clear(); // previously left untouched, so repeated calls appended stale entries
    edge_indices.clear();

    /* generate the edge distribution */
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &shape = *scene.shape_list[i];
        if (!shape.enable_edge)
            continue;
        const BSDF &bsdf = *scene.bsdf_list[shape.bsdf_id];
        if (bsdf.isNull() && shape.light_id < 0)
            continue;

        Float total_dist = 0.0;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];

            if (edge.mode == 0 || bsdf.isTransmissive())
            {
                const auto weight = edge.length * 4 * M_PI; // length * area of a unit sphere
                edge_indices.push_back({i, j});
                edge_dist.append(weight);
                if (!shape.enable_draw)
                {
                    draw_dist.append(weight);
                    total_dist = 0.0;
                }
                total_dist += weight;
            }
            if (edge.mode > 0)
            {
                const Vector &n0 = shape.getGeoNormal(edge.f0),
                             &n1 = shape.getGeoNormal(edge.f1);
                Float cos_angle = n0.dot(n1);
                // Rounding can push |n0 . n1| slightly above 1, which would make acos NaN.
                cos_angle = std::clamp(cos_angle, Float(-1), Float(1));
                const auto weight = edge.length * 4 * std::acos(cos_angle);
                edge_indices.push_back({i, j});
                if (!shape.enable_draw)
                {
                    draw_dist.append(weight);
                    total_dist = 0.0;
                }
                edge_dist.append(weight);
                total_dist += weight;
            }
            // Close the current drawable run when the next edge does not continue it.
            if (shape.enable_draw && j + 1 < shape.edges.size())
            {
                const Edge &edge_next = shape.edges[j + 1];
                if (edge.v1 != edge_next.v0)
                {
                    draw_dist.append(total_dist);
                    total_dist = 0.0;
                }
            }
        }
        if (shape.enable_draw)
            draw_dist.append(total_dist);
    }
    edge_dist.normalize();
    draw_dist.normalize();
}

/*
 * Clears and rebuilds edge_indices, edge_dist and draw_dist from `scene`, using
 * the same weighting rules as configure(), then prints the number of draw_dist
 * entries to stdout.
 *   Full-sphere sampling (mode 0 or transmissive BSDF): weight = length * 4*pi.
 *   Wedge sampling (mode > 0): weight = length * 4*acos(n0 . n1).
 * For enable_draw shapes, draw_dist gets one entry per run of consecutive edges
 * connected end-to-start.
 */
void IndirectEdgeIntegrator::recompute_edge(const Scene &scene)
{
    edge_dist.clear();
    draw_dist.clear();
    edge_indices.clear();
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &shape = *scene.shape_list[i];
        if (!shape.enable_edge)
            continue;
        const BSDF &bsdf = *scene.bsdf_list[shape.bsdf_id];
        if (bsdf.isNull() && shape.light_id < 0)
            continue;

        Float total_dist = 0.0;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];

            if (edge.mode == 0 || bsdf.isTransmissive())
            {
                const auto weight = edge.length * 4 * M_PI; // length * area of a unit sphere
                edge_indices.push_back({i, j});
                edge_dist.append(weight);
                if (!shape.enable_draw)
                {
                    draw_dist.append(weight);
                    total_dist = 0.0;
                }
                total_dist += weight;
            }
            if (edge.mode > 0)
            {
                const Vector &n0 = shape.getGeoNormal(edge.f0),
                             &n1 = shape.getGeoNormal(edge.f1);
                Float cos_angle = n0.dot(n1);
                // Rounding can push |n0 . n1| slightly above 1, which would make acos NaN.
                cos_angle = std::clamp(cos_angle, Float(-1), Float(1));
                const auto weight = edge.length * 4 * std::acos(cos_angle);
                edge_indices.push_back({i, j});
                if (!shape.enable_draw)
                {
                    draw_dist.append(weight);
                    total_dist = 0.0;
                }
                edge_dist.append(weight);
                total_dist += weight;
            }
            // Close the current drawable run when the next edge does not continue it.
            if (shape.enable_draw && j + 1 < shape.edges.size())
            {
                const Edge &edge_next = shape.edges[j + 1];
                if (edge.v1 != edge_next.v0)
                {
                    draw_dist.append(total_dist);
                    total_dist = 0.0;
                }
            }
        }
        if (shape.enable_draw)
            draw_dist.append(total_dist);
    }
    edge_dist.normalize();
    draw_dist.normalize();
    std::cout << "edge draw size: " << draw_dist.m_cdf.size() - 1 << std::endl;
}

/**
 * Builds the 3D grid guiding distribution (grid_distrb) for indirect boundary sampling.
 *
 * Steps:
 *  1. Traces opts.num_cam_path camera paths (radiance nodes, depth max_bounces + 1)
 *     and opts.num_light_path light paths (importance nodes, depth max_bounces)
 *     with buildPhotonMap.
 *  2. Indexes the node positions in two KD-trees.
 *  3. Tabulates eval_photon_InDirectBoundary over the 3D primary-sample domain via
 *     grid_distrb.setup.
 *
 * @param scene        scene to trace against
 * @param config       grid configuration forwarded unchanged to grid_distrb.setup
 * @param max_bounces  maximum path depth used for the importance paths and for evaluation
 */
void IndirectEdgeIntegrator::preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
{
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::vector<RadImpNode> rad_nodes, imp_nodes;

    std::cout << "[INFO] Indirect Guiding: #camPath = " << opts.num_cam_path << ", #lightPath = " << opts.num_light_path << std::endl;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces + 1, rad_nodes, false);
    buildPhotonMap(scene, opts.num_light_path, max_bounces, imp_nodes, true);
    std::cout << "[INFO] Indirect Guiding: #rad_nodes = " << rad_nodes.size() << ", #imp_nodes = " << imp_nodes.size() << std::endl;

    // Copy node positions into a point cloud usable by the KD-tree adaptor.
    auto fillCloud = [](const std::vector<RadImpNode> &nodes, PointCloud<Float> &cloud)
    {
        cloud.pts.resize(nodes.size());
        for (size_t k = 0; k < nodes.size(); ++k)
        {
            cloud.pts[k].x = nodes[k].p[0];
            cloud.pts[k].y = nodes[k].p[1];
            cloud.pts[k].z = nodes[k].p[2];
        }
    };
    // NOTE(review): nanoflann adaptors presumably hold a reference to the cloud,
    // so the clouds must outlive the trees (they do: same scope).
    PointCloud<Float> rad_cloud, imp_cloud;
    fillCloud(rad_nodes, rad_cloud);
    fillCloud(imp_nodes, imp_cloud);

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    std::cout << "preprocessing IndirectEdgeIntegrator" << std::endl;
    // Guiding function evaluated by grid_distrb.setup; captures locals by reference,
    // which is safe because setup runs (and finishes) inside this scope.
    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        // Float result = eval_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces);
        Float result = eval_photon_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, imp_nodes, rad_indices, imp_indices, -1, false);
        // Reject NaN *and* +/-inf: a single non-finite cell would poison the
        // grid's total and therefore every tabulated probability.
        if (std::isfinite(result))
        {
            return result;
        }
        return static_cast<Float>(0.0);
    };
    grid_distrb.setup(NEE_function, config);
    std::cout << "finish grid guiding" << std::endl;
}

/**
 * Builds the adaptive-quadrature guiding distribution (aq_distrb) for indirect
 * boundary sampling.
 *
 * Steps:
 *  1. Traces camera (radiance) and light (importance) paths with buildPhotonMap.
 *  2. Indexes the node positions in two KD-trees.
 *  3. Calls aq_distrb.setup with a guiding function built on
 *     eval_photon_InDirectBoundary and an initial CDF. The CDF is draw_dist's
 *     when config.edge_draw is set, otherwise edge_dist's.
 *
 * @param scene        scene to trace against
 * @param config       AQ configuration (eps floor, edge_draw, shape_opt_id, local_backward, ...)
 * @param max_bounces  maximum path depth used for the importance paths and for evaluation
 */
void IndirectEdgeIntegrator::preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
{
    PhotonMapOptions opts(10000, 10000, max_bounces);
    std::vector<RadImpNode> rad_nodes, imp_nodes;

    std::cout << "[INFO] Indirect Guiding: #camPath = " << opts.num_cam_path << ", #lightPath = " << opts.num_light_path << std::endl;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces + 1, rad_nodes, false);
    buildPhotonMap(scene, opts.num_light_path, max_bounces, imp_nodes, true);
    std::cout << "[INFO] Indirect Guiding: #rad_nodes = " << rad_nodes.size() << ", #imp_nodes = " << imp_nodes.size() << std::endl;

    // Copy node positions into a point cloud usable by the KD-tree adaptor.
    auto fillCloud = [](const std::vector<RadImpNode> &nodes, PointCloud<Float> &cloud)
    {
        cloud.pts.resize(nodes.size());
        for (size_t k = 0; k < nodes.size(); ++k)
        {
            cloud.pts[k].x = nodes[k].p[0];
            cloud.pts[k].y = nodes[k].p[1];
            cloud.pts[k].z = nodes[k].p[2];
        }
    };
    // NOTE(review): nanoflann adaptors presumably hold a reference to the cloud,
    // so the clouds must outlive the trees (they do: same scope).
    PointCloud<Float> rad_cloud, imp_cloud;
    fillCloud(rad_nodes, rad_cloud);
    fillCloud(imp_nodes, imp_cloud);

    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    std::cout << "AQ preprocessing IndirectEdgeIntegrator" << std::endl;
    // Guiding function evaluated by aq_distrb.setup; captures locals by reference,
    // which is safe because setup runs (and finishes) inside this scope.
    auto NEE_function = [&](const Vector &AQ_rnd, RndSampler &sampler)
    {
        // Float result = eval_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces);
        Float result = eval_photon_InDirectBoundary(AQ_rnd, scene, sampler, edge_dist, edge_indices, max_bounces, rad_nodes, imp_nodes, rad_indices, imp_indices, config.shape_opt_id, config.local_backward);
        // Values are floored at config.eps, presumably so that no region of the
        // domain ends up with zero sampling probability. Non-finite values
        // (NaN, +/-inf) also fall back to eps: an infinite cell would otherwise
        // drive every other cell's probability to 0/NaN and defeat that floor.
        if (!std::isfinite(result) || result < config.eps)
        {
            return config.eps;
        }
        return result;
    };

    std::vector<Float> cdfx;
    if (config.edge_draw)
    {
        std::cout << "AQ using draw edge" << std::endl;
        cdfx = draw_dist.m_cdf;
    }
    else
    {
        std::cout << "AQ using individual edge" << std::endl;
        cdfx = edge_dist.m_cdf;
    }

    std::cout << "Inital curvature size: " << cdfx.size() << std::endl;
    aq_distrb.setup(NEE_function, cdfx, config);
    // aq_distrb.print();
    std::cout << "finish AQ guiding" << std::endl;
}

/**
 * Reverse-mode estimate of the indirect boundary (edge) term.
 *
 * The image is split into nblocks blocks of block_size^2 pixels. Each block draws
 * block_size^2 * num_samples_secondary_edge_indirect boundary samples. A sample's
 * primary sample is warped by grid_distrb or aq_distrb when one of them has been
 * built (non-zero sum); otherwise it is used uniformly with pdf 1. Each sample is
 * then passed to d_sampleDirectBoundary, which accumulates gradients into the
 * per-thread slot gm.get(tid). The per-thread slots are merged at the end.
 *
 * @param sceneAD    scene value (val), derivative (der) and per-thread gradient manager (gm)
 * @param options    render options (sample count, block size, seed)
 * @param __d_image  adjoint of the rendered image; divided by the total sample count
 * @return flattened debug-info array; the same flattening is applied on every exit path
 */
ArrayXd IndirectEdgeIntegrator::renderD(SceneAD &sceneAD,
                                        RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("IndirectEdgeIntegrator::renderD with spp = {}",
              options.num_samples_secondary_edge_indirect);
    const int nsamples = options.num_samples_secondary_edge_indirect;
    const Scene &scene = sceneAD.val;
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        // Flatten here as well so callers get the same layout as the normal exit.
        return flattened(debugInfo.getArray());

    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());

    // The guiding distributions are members and this method is const, so the
    // sampling strategy cannot change mid-render: decide it once, not per sample.
    const bool use_grid = grid_distrb.distrb.getSum() > Epsilon;
    const bool use_aq = !use_grid && aq_distrb.distrb.getSum() > Epsilon;

    Timer _("Indirect boundary");
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        const int tid = omp_get_thread_num();
        for (int j = 0; j < nblock_samples; j++)
        {
            // NOTE(review): i * nblock_samples + j is computed in int and may overflow
            // for very large pixel counts x spp; check RndSampler's index type.
            RndSampler sampler(options.seed, i * nblock_samples + j);
            Vector3 d_rnd = sampler.next3D();
            Float d_pdf = 1.0;
            if (use_grid)
                d_rnd = grid_distrb.sample(d_rnd, d_pdf);
            else if (use_aq)
                d_rnd = aq_distrb.sample(d_rnd, d_pdf);

            // sample a point on the boundary
            d_sampleDirectBoundary(scene, gm.get(tid),
                                   sampler, options,
                                   edge_dist, edge_indices,
                                   _d_image_spec_list,
                                   d_rnd, d_pdf);
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scenes
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif
    return flattened(debugInfo.getArray());
}

/**
 * Indirect boundary gradient pass driven by a constant per-pixel adjoint.
 *
 * Sampling and accumulation follow the same block/sample loop as the image-adjoint
 * version. The one difference is the adjoint: every pixel gets Spectrum(1, 0, 0),
 * divided by the total sample count (nblock_samples * nblocks). No caller-supplied
 * image gradient is used.
 *
 * @param sceneAD  scene value (val), derivative (der) and per-thread gradient manager (gm)
 * @param options  render options (sample count, block size, seed)
 * @return flattened debug-info array; the same flattening is applied on every exit path
 */
ArrayXd IndirectEdgeIntegrator::forwardRenderD(SceneAD &sceneAD,
                                        RenderOptions &options) const
{
    PSDR_INFO("IndirectEdgeIntegrator::forwardRenderD with spp = {}",
              options.num_samples_secondary_edge_indirect);
    const int nsamples = options.num_samples_secondary_edge_indirect;
    const Scene &scene = sceneAD.val;
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        // Flatten here as well so callers get the same layout as the normal exit.
        return flattened(debugInfo.getArray());

    std::vector<Spectrum> _d_image_spec_list(camera.getNumPixels(), Spectrum(1, 0, 0) / nblock_samples / nblocks);

    // The guiding distributions are members and this method is const, so the
    // sampling strategy cannot change mid-render: decide it once, not per sample.
    const bool use_grid = grid_distrb.distrb.getSum() > Epsilon;
    const bool use_aq = !use_grid && aq_distrb.distrb.getSum() > Epsilon;

    Timer _("Indirect boundary");
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        const int tid = omp_get_thread_num();
        for (int j = 0; j < nblock_samples; j++)
        {
            // NOTE(review): i * nblock_samples + j is computed in int and may overflow
            // for very large pixel counts x spp; check RndSampler's index type.
            RndSampler sampler(options.seed, i * nblock_samples + j);
            Vector3 d_rnd = sampler.next3D();
            Float d_pdf = 1.0;
            if (use_grid)
                d_rnd = grid_distrb.sample(d_rnd, d_pdf);
            else if (use_aq)
                d_rnd = aq_distrb.sample(d_rnd, d_pdf);

            // sample a point on the boundary
            d_sampleDirectBoundary(scene, gm.get(tid),
                                   sampler, options,
                                   edge_dist, edge_indices,
                                   _d_image_spec_list,
                                   d_rnd, d_pdf);
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scenes
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif
    return flattened(debugInfo.getArray());
}