#include "integratorADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "math_func.h"
#include "bidir_utils.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <fstream>

// Number of neighbours requested from the k-nearest-neighbour lookups in queryPhotonMap().
#define NUM_NEAREST_NEIGHBOR 20
// Upper bound on the number of matches kept from a radius search in queryPhotonMap();
// preprocessIndirect() sizes its matched-index buffer with this value.
#define MAX_NEAREST_NEIGHBOR 200
// Range to which BSDF values are clamped while gathering photon-map nodes in preprocessIndirect().
#define BSDF_CLAMPING_MAX 1e4
#define BSDF_CLAMPING_MIN 1e-4
// Uncomment to enable the additional [INFO] logging in queryPhotonMap().
// #define VERBOSE

// Note: this code assumes a pixel filter is in use; without one, the sensor_val obtained from camera connections may be incorrect.

// Scratch storage for path construction. buildPhotonMap() uses row 2*tid for the camera path and
// row 2*tid+1 for the light path of OpenMP thread tid, so tid must stay below BDPT_MAX_THREADS.
static bidir::PathNode paths[2*BDPT_MAX_THREADS][BDPT_MAX_PATH_LENGTH];

// Evaluates one sampled boundary-edge ray for the direct (next-event estimation) case, i.e. when the
// forward hit of the edge ray lies on an emitter.
//
// Parameters:
//   scene, shape_id, rEdge : the scene, the index of the shape owning the edge, and the edge itself.
//   edgeRay                : ray through a point on the edge; it is traced in both directions.
//   sampler                : random source used for the sensor connection and by weightedImportance().
//   max_bounces            : the indirect part is only evaluated when this is greater than 1.
//   eRec                   : output record. its1/its2 receive the backward/forward hits, value0 the
//                            gradient of the direct camera connection, value1 the gradient of the
//                            geometric boundary term, plus the pixel index/indices and sensor values.
//   record                 : forwarded unchanged to weightedImportance().
//
// Returns the value returned by weightedImportance() when the indirect part was evaluated, 0 otherwise
// (preprocessDirect() in this file treats it as the number of usable entries in `record`).
int IntegratorAD_PathSpace::evalEdgeDirect(const Scene &scene, int shape_id, const Edge &rEdge, const RayAD &edgeRay, RndSampler *sampler, int max_bounces,
                                           EdgeEvaluationRecord &eRec, std::pair<int, Spectrum>* record) const
{
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), evalEdgeDirect() should not be used." << std::endl;
    assert(false);
#endif
    Intersection &its1 = eRec.its1, &its2 = eRec.its2;
    int &idx_pixel = eRec.idx_pixel, ret = 0;

    // Reset every output so that early exits leave a well-defined "no contribution" record.
    eRec.value0.zero();
    eRec.value1.zero();
    idx_pixel = -1;

    eRec.sensor_vals = Array4(0.0);
    eRec.pixel_indices = Array4i(-1);

    const Shape &shape = *scene.shape_list[shape_id];
    Ray _edgeRay = edgeRay.toRay();
    const Vector &d1 = _edgeRay.dir;
    // its2 is the hit along the edge ray, its1 the hit along the flipped ray; both must exist.
    if ( scene.rayIntersect(_edgeRay, true, its2) && scene.rayIntersect(_edgeRay.flipped(), true, its1) ) {
        // Reject hits on the two faces adjacent to the edge itself, and require the forward hit to be an emitter.
        const Vector2i ind0(shape_id, rEdge.f0), ind1(shape_id, rEdge.f1);
        if ( its1.indices != ind0 && its1.indices != ind1 && its2.indices != ind0 && its2.indices != ind1 && its2.isEmitter() ) {
            // Cosines of the connecting direction w.r.t. the geometric (gn*) and shading (sn*) normals at both hits.
            const Float gn1d1 = its1.geoFrame.n.dot(d1), sn1d1 = its1.shFrame.n.dot(d1),
                        gn2d1 = its2.geoFrame.n.dot(-d1), sn2d1 = its2.shFrame.n.dot(-d1);
            assert(std::abs(its1.wi.z() - sn1d1) < Epsilon && std::abs(its2.wi.z() - sn2d1) < Epsilon);

            // Transmissive BSDFs only need the geometric and shading cosines to agree in sign;
            // otherwise both cosines must be positive (greater than Epsilon).
            bool valid1 = (its1.ptr_bsdf->isTransmissive() && math::signum(gn1d1)*math::signum(sn1d1) > 0.5f) || (!its1.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
                 valid2 = (its2.ptr_bsdf->isTransmissive() && math::signum(gn2d1)*math::signum(sn2d1) > 0.5f) || (!its2.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
            if ( valid1 && valid2 ) {
                // const Shape &shape1 = *scene.shape_list[its1.indices[0]]; const Vector3i &f1 = shape1.getIndices(its1.indices[1]);
                // AD vertices of the triangle hit by the forward ray.
                const Shape &shape2 = *scene.shape_list[its2.indices[0]]; const Vector3i &f2 = shape2.getIndices(its2.indices[1]);
                const VectorAD &v0 = shape2.getVertexAD(f2[0]), &v1 = shape2.getVertexAD(f2[1]), &v2 = shape2.getVertexAD(f2[2]);
                // Scalar weight of this boundary sample; stays 0 when the configuration is degenerate.
                Float baseValue = 0.0f;
                // Unit direction perpendicular to its2's geometric normal, onto which the AD motion u2 is projected below.
                Vector n;
                {
                    Float dist = (its2.p - its1.p).norm();
                    Float cos2 = std::abs(gn2d1);
                    // e = (unit edge direction) x d1; its length is the sine of the angle between the edge and the ray.
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(d1);
                    Float sinphi = e.norm();
                    Vector proj = e.cross(its2.geoFrame.n).normalized();
                    Float sinphi2 = d1.cross(proj).norm();
                    n = its2.geoFrame.n.cross(proj).normalized();

                    // Sign of the contribution, derived from the side on which vertex v2 of the edge record lies.
                    Float deltaV;
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));

                    if ( sinphi > Epsilon && sinphi2 > Epsilon )
                        baseValue = deltaV*(its1.t/dist)*(sinphi/sinphi2)*cos2;
                }

                if ( std::abs(baseValue) > Epsilon ) {
                    // Output of rayIntersectTriangleAD(); only its gradients are used here.
                    VectorAD u2;

                    // Direct
                    {
                        Vector2 pix_uv;
                        Vector d0;
                        // Connect its1 to the sensor; pix_uv and d0 are filled by the call.
                        Float sensor_val = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uv, d0);
                        if ( sensor_val > Epsilon ) {
                            // Re-trace the primary ray with AD and make sure it reaches the same point as its1.
                            RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
                            IntersectionAD its;
                            if ( scene.rayIntersectAD(cameraRay, false, its) && (its.p.val - its1.p).norm() < ShadowEpsilon ) {
                                bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(its.p, edgeRay.org - its.p), u2);
                                if (valid) {
                                    Vector d0_local = its1.toLocal(d0);
                                    // BSDF towards the camera, times the geometric weight, times the emitted radiance at its2.
                                    Spectrum value0 = its1.evalBSDF(d0_local, EBSDFMode::EImportanceWithCorrection) * baseValue * its2.Le(-d1);
                                    for ( int j = 0; j < nder; ++j )
                                        eRec.value0.grad(j) = value0*n.dot(u2.grad(j));
                                    
                                    // Second sensor connection using the Matrix2x4 overload: fills four sensor values
                                    // and the four corresponding pixel indices. Note that it overwrites d0.
                                    Matrix2x4 pix_uvs;
                                    eRec.sensor_vals = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uvs, d0);
                                    for (int i = 0; i < 4; i++) {
                                        eRec.pixel_indices[i] = scene.camera.getPixelIndex(pix_uvs.col(i));
                                    }
                                    idx_pixel = scene.camera.getPixelIndex(pix_uv);
                                }
                            }
                        }
                    }

                    // indirect
                    if ( max_bounces > 1 ) {
                        VectorAD x1, n1;
                        FloatAD J1;
                        // AD version of the point its1 (x1); n1 and J1 are filled by getPoint() but not used here.
                        scene.getPoint(its1, x1, n1, J1);
                        bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2);
                        if (valid) {
                            for ( int j = 0; j < nder; ++j )
                                eRec.value1.grad(j) = baseValue*n.dot(u2.grad(j));
                            ret = weightedImportance(scene, sampler, its1, max_bounces - 1, nullptr, record);
                        }
                    }
                }
            }
        }

        // Sanitise the outputs: report and zero any non-finite gradient component.
        for ( int j = 0; j < nder; ++j ) {
            if ( !std::isfinite(eRec.value0.grad(j)[0]) || !std::isfinite(eRec.value0.grad(j)[1]) || !std::isfinite(eRec.value0.grad(j)[2]) ) {
                omp_set_lock(&messageLock);
                std::cerr << std::fixed << std::setprecision(2)
                          << "\n[WARN] Invalid gradient: [" << eRec.value0.grad(j).transpose() << "]" << std::endl;
                omp_unset_lock(&messageLock);
                eRec.value0.grad(j).setZero();
            }

            if ( !std::isfinite(eRec.value1.grad(j)) ) {
                omp_set_lock(&messageLock);
                std::cerr << std::fixed << std::setprecision(2)
                          << "\n[WARN] Invalid gradient: [" << eRec.value1.grad(j) << "]" << std::endl;
                omp_unset_lock(&messageLock);
                eRec.value1.grad(j) = 0.0f;
            }
        }
    }
    return ret;
}

// Evaluates one sampled boundary-edge ray without requiring the forward hit to be an emitter.
//
// Parameters:
//   scene, shape_id, rEdge : the scene, the index of the shape owning the edge, and the edge itself.
//   edgeRay                : ray through a point on the edge; it is traced in both directions.
//   sampler                : random source used for the sensor connection.
//   eRec                   : output record. its1/its2 receive the backward/forward hits, value0 the
//                            gradient of the direct camera connection (BSDF times geometric weight),
//                            value1 the gradient of the geometric boundary term alone, plus the pixel
//                            index/indices and sensor values of the camera connection.
//
// Unlike evalEdgeDirect(), value0 does not include emitted radiance and the indirect part is always evaluated.
void IntegratorAD_PathSpace::evalEdge(const Scene &scene, int shape_id, const Edge &rEdge, const RayAD &edgeRay, RndSampler *sampler, EdgeEvaluationRecord &eRec) const {
    Intersection &its1 = eRec.its1, &its2 = eRec.its2;
    int &idx_pixel = eRec.idx_pixel;

    // Reset every output so that early exits leave a well-defined "no contribution" record.
    eRec.value0.zero();
    eRec.value1.zero();
    idx_pixel = -1;

    eRec.sensor_vals = Array4(0.0);
    eRec.pixel_indices = Array4i(-1);

    const Shape &shape = *scene.shape_list[shape_id];

    Ray _edgeRay = edgeRay.toRay();
    const Vector &d1 = _edgeRay.dir;
    // its2 is the hit along the edge ray, its1 the hit along the flipped ray; both must exist.
    if ( scene.rayIntersect(_edgeRay, true, its2) && scene.rayIntersect(_edgeRay.flipped(), true, its1) ) {
        // Reject hits on the two faces adjacent to the edge itself.
        const Vector2i ind0(shape_id, rEdge.f0), ind1(shape_id, rEdge.f1);
        if ( its1.indices != ind0 && its1.indices != ind1 && its2.indices != ind0 && its2.indices != ind1 ) {
            // Cosines of the connecting direction w.r.t. the geometric (gn*) and shading (sn*) normals at both hits.
            const Float gn1d1 = its1.geoFrame.n.dot(d1), sn1d1 = its1.shFrame.n.dot(d1),
                        gn2d1 = its2.geoFrame.n.dot(-d1), sn2d1 = its2.shFrame.n.dot(-d1);
            assert(std::abs(its1.wi.z() - sn1d1) < Epsilon && std::abs(its2.wi.z() - sn2d1) < Epsilon);

            // Transmissive BSDFs only need the geometric and shading cosines to agree in sign;
            // otherwise both cosines must be positive (greater than Epsilon).
            bool valid1 = (its1.ptr_bsdf->isTransmissive() && math::signum(gn1d1)*math::signum(sn1d1) > 0.5f) || (!its1.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
                 valid2 = (its2.ptr_bsdf->isTransmissive() && math::signum(gn2d1)*math::signum(sn2d1) > 0.5f) || (!its2.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
            if ( valid1 && valid2 ) {
                // const Shape &shape1 = *scene.shape_list[its1.indices[0]]; const Vector3i &f1 = shape1.getIndices(its1.indices[1]);
                // AD vertices of the triangle hit by the forward ray.
                const Shape &shape2 = *scene.shape_list[its2.indices[0]]; const Vector3i &f2 = shape2.getIndices(its2.indices[1]);
                const VectorAD &v0 = shape2.getVertexAD(f2[0]), &v1 = shape2.getVertexAD(f2[1]), &v2 = shape2.getVertexAD(f2[2]);

                // Scalar weight of this boundary sample; stays 0 when the configuration is degenerate.
                Float baseValue = 0.0f;
                // Unit direction perpendicular to its2's geometric normal, onto which the AD motion u2 is projected below.
                Vector n;
                {
                    Float dist = (its2.p - its1.p).norm();
                    Float cos2 = std::abs(gn2d1);
                    // e = (unit edge direction) x d1; its length is the sine of the angle between the edge and the ray.
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(d1);
                    Float sinphi = e.norm();
                    Vector proj = e.cross(its2.geoFrame.n).normalized();
                    Float sinphi2 = d1.cross(proj).norm();
                    n = its2.geoFrame.n.cross(proj).normalized();

                    // Sign of the contribution, derived from the side on which vertex v2 of the edge record lies.
                    Float deltaV;
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));

                    if ( sinphi > Epsilon && sinphi2 > Epsilon )
                        baseValue = deltaV*(its1.t/dist)*(sinphi/sinphi2)*cos2;
                }

                if ( std::abs(baseValue) > Epsilon ) {
                    // Output of rayIntersectTriangleAD(); only its gradients are used here.
                    VectorAD u2;

                    // Direct
                    {
                        Vector2 pix_uv;
                        Vector d0;
                        // Connect its1 to the sensor; pix_uv and d0 are filled by the call.
                        Float sensor_val = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uv, d0);
                        if ( sensor_val > Epsilon ) {
                            // Re-trace the primary ray with AD and make sure it reaches the same point as its1.
                            RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
                            IntersectionAD its;
                            if ( scene.rayIntersectAD(cameraRay, false, its) && (its.p.val - its1.p).norm() < ShadowEpsilon ) {
                                bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(its.p, edgeRay.org - its.p), u2);
                                if (valid) {
                                    Vector d0_local = its1.toLocal(d0);
                                    // Float correction = std::abs((its1.wi.z()*d0.dot(its1.geoFrame.n))/(d0_local.z()*d1.dot(its1.geoFrame.n)));
                                    // BSDF towards the camera times the geometric weight (no emission term here).
                                    Spectrum value0 = its1.evalBSDF(d0_local, EBSDFMode::EImportanceWithCorrection)*baseValue;
                                    for ( int j = 0; j < nder; ++j )
                                        eRec.value0.grad(j) = value0*n.dot(u2.grad(j));

                                    // Second sensor connection using the Matrix2x4 overload: fills four sensor values
                                    // and the four corresponding pixel indices. Note that it overwrites d0.
                                    Matrix2x4 pix_uvs;
                                    eRec.sensor_vals = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uvs, d0);
                                    for (int i = 0; i < 4; i++) {
                                        eRec.pixel_indices[i] = scene.camera.getPixelIndex(pix_uvs.col(i));
                                    }
                                    idx_pixel = scene.camera.getPixelIndex(pix_uv);
                                }
                            }
                        }
                    }

                    // Indirect
                    {
                        VectorAD x1, n1;
                        FloatAD J1;
                        // AD version of the point its1 (x1); n1 and J1 are filled by getPoint() but not used here.
                        scene.getPoint(its1, x1, n1, J1);
                        bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2);
                        if (valid) {
                            for ( int j = 0; j < nder; ++j ) {
                                eRec.value1.grad(j) = baseValue*n.dot(u2.grad(j));
                            }
                        }
                    }
                }
            }
        }

        // Sanitise the outputs: report and zero any non-finite gradient component.
        for ( int j = 0; j < nder; ++j ) {
            if ( !std::isfinite(eRec.value0.grad(j)[0]) || !std::isfinite(eRec.value0.grad(j)[1]) || !std::isfinite(eRec.value0.grad(j)[2]) ) {
                omp_set_lock(&messageLock);
                std::cerr << std::fixed << std::setprecision(2)
                          << "\n[WARN] Invalid gradient: [" << eRec.value0.grad(j).transpose() << "]" << std::endl;
                omp_unset_lock(&messageLock);
                eRec.value0.grad(j).setZero();
            }

            if ( !std::isfinite(eRec.value1.grad(j)) ) {
                omp_set_lock(&messageLock);
                std::cerr << std::fixed << std::setprecision(2)
                          << "\n[WARN] Invalid gradient: [" << eRec.value1.grad(j) << "]" << std::endl;
                omp_unset_lock(&messageLock);
                eRec.value1.grad(j) = 0.0f;
            }
        }
    }
}

void IntegratorAD_PathSpace::preprocessDirect(const Scene &scene, const std::vector<int> &params, int max_bounces, ptr<float> data, bool quiet) const {
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), preprocessDirect() should not be used." << std::endl;
    assert(false);
#endif
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( int omp_i = 0; omp_i < params[0]*params[1]; ++omp_i ) {
        std::pair<int, Spectrum> importance[BDPT_MAX_PATH_LENGTH + 1];
        EdgeEvaluationRecord eRec;
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];

        const int i = omp_i/params[1], j = omp_i % params[1];
        for ( int k = 0; k < params[2]; ++k ) {
            Float res = 0.0f;
            for ( int t = 0; t < params[3]; ++t ) {
                Vector rnd = sampler.next3D();
                rnd[0] = (rnd[0] + i)/static_cast<Float>(params[0]);
                rnd[1] = (rnd[1] + j)/static_cast<Float>(params[1]);
                rnd[2] = (rnd[2] + k)/static_cast<Float>(params[2]);

                int shape_id;
                RayAD edgeRay;
                Float edgePdf, value = 0.0f;
                const Edge &rEdge = scene.sampleEdgeRayDirect(rnd, shape_id, edgeRay, edgePdf);
                if ( shape_id >= 0 ) {
                    int num_indirect_path = evalEdgeDirect(scene, shape_id, rEdge, edgeRay, &sampler, max_bounces, eRec, importance);
                    if (!eRec.sensor_vals.isZero()) {
                        Float avg_sensor_val = 0.0;
                        Float sensor_val_cnt = 0.0;
                        for (int pixel_i = 0; pixel_i < 4; pixel_i++) {
                            if (eRec.sensor_vals[pixel_i] > Epsilon) {
                                avg_sensor_val += eRec.sensor_vals[pixel_i];
                                sensor_val_cnt += 1.0;
                            }
                        }
                        assert(sensor_val_cnt > 0.0);
                        avg_sensor_val /= sensor_val_cnt;

                        Float val = eRec.value0.der.abs().maxCoeff() * avg_sensor_val / edgePdf;
                        if ( std::isfinite(val) ) value += val;
                    }

                    if ( num_indirect_path > 0) {
                        for (int m = 0; m < num_indirect_path; m++) {
                            assert(eRec.its2.isEmitter());
                            Float val = (eRec.value1.der.abs().maxCoeff() * eRec.its2.Le(-edgeRay.dir.val) * importance[m].second).maxCoeff()/edgePdf;
                            if ( std::isfinite(val) ) value += val;
                        }
                    }
                }
                res += value;
            }
            Float avg = res/static_cast<Float>(params[3]);
            data[static_cast<long long>(omp_i)*params[2] + k] = static_cast<float>(avg);
        }

        if ( !quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(static_cast<Float>(omp_i)/(params[0]*params[1]));
            omp_unset_lock(&messageLock);
        }
    }
}

void IntegratorAD_PathSpace::buildPhotonMap(const Scene &scene, const GuidingOptions& opts, int max_bounces,
                                            std::vector<MapNode> &rad_nodes, std::vector<MapNode> &imp_nodes) const
{
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(17, i));

    std::vector< std::vector<MapNode> > rad_nodes_per_thread(nworker);
    std::vector< std::vector<MapNode> > imp_nodes_per_thread(nworker);
    for (int i = 0; i < nworker; i++) {
        rad_nodes_per_thread[i].reserve(opts.num_cam_path/nworker * max_bounces);
        imp_nodes_per_thread[i].reserve(opts.num_light_path/nworker * max_bounces);
    }

    const Camera &camera = scene.camera;
    const CropRectangle &rect = camera.rect;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (size_t omp_i = 0; omp_i < opts.num_cam_path; omp_i++) {
        const int tid = omp_get_thread_num();
        bidir::PathNode *cameraPath = paths[2*tid];
        RndSampler &sampler = samplers[tid];

        Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width
                                 : sampler.next1D() * camera.width;
        Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height
                                 : sampler.next1D() * camera.height;
        Ray cameraRay = camera.samplePrimaryRay(x, y);
        int cameraPathLen;
        if ( scene.rayIntersect(cameraRay, false, cameraPath[0].its) ) {
            cameraPath[0].throughput = Spectrum(1.0f);
            cameraPathLen = 1;
            if ( max_bounces > 0 )
                cameraPathLen = bidir::buildPath(scene, &sampler, max_bounces, false, cameraPath);
        } else
            cameraPathLen = 0;

        for (int i = 0; i < cameraPathLen; i++) {
            assert(i <= max_bounces);
            rad_nodes_per_thread[tid].push_back( MapNode{cameraPath[i].its, cameraPath[i].throughput, i} );
        }
    }
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (size_t omp_i = 0; omp_i < opts.num_light_path; omp_i++) {
        const int tid = omp_get_thread_num();
        bidir::PathNode *lightPath = paths[2*tid+1];
        RndSampler &sampler = samplers[tid];
        lightPath[0].throughput = scene.sampleEmitterPosition(sampler.next2D(), lightPath[0].its, &lightPath[0].pdf0);
        int lightPathLen = 1;
        if ( max_bounces > 0 ) {
            Vector wo;
            Float tmp = lightPath[0].its.ptr_emitter->sampleDirection(sampler.next2D(), wo, &lightPath[1].pdf0);
            wo = lightPath[0].its.geoFrame.toWorld(wo);
            if ( scene.rayIntersect(Ray(lightPath[0].its.p, wo), true, lightPath[1].its) ) {
                lightPath[0].wo = wo;
                lightPath[1].throughput = lightPath[0].throughput*tmp;
                if ( max_bounces > 1 )
                    lightPathLen = bidir::buildPath(scene, &sampler, max_bounces, true, &lightPath[1]) + 1;
                else
                    lightPathLen = 2;
            }
        }
#ifdef USE_BOUNDARY_NEE
        int lightPathStart = 1;
#else
        int lightPathStart = 0;
#endif
        for (int i = lightPathStart; i < lightPathLen; i++) {
            assert(i <= max_bounces);
            imp_nodes_per_thread[tid].push_back( MapNode{lightPath[i].its, lightPath[i].throughput, i} );
        }
    }

    size_t sz_rad = 0, sz_imp = 0;
    for (int i = 0; i < nworker; i++) {
        sz_rad += rad_nodes_per_thread[i].size();
        sz_imp += imp_nodes_per_thread[i].size();
    }
    rad_nodes.reserve(sz_rad);
    imp_nodes.reserve(sz_imp);
    for (int i = 0; i < nworker; i++) {
        rad_nodes.insert(rad_nodes.end(), rad_nodes_per_thread[i].begin(), rad_nodes_per_thread[i].end());
        imp_nodes.insert(imp_nodes.end(), imp_nodes_per_thread[i].begin(), imp_nodes_per_thread[i].end());
    }
}

// Looks up photon-map nodes around query_point.
//
// Parameters:
//   indices          : k-d tree built over the node positions.
//   opts             : opts.type selects the strategy (1 or 2), opts.search_radius is the search radius
//                      in the same units as the distances returned by the tree (logged as "d2" below).
//   query_point      : position to search around (3 coordinates at the call sites in this file).
//   matched_indices  : output buffer for the indices of the matched nodes; must hold at least
//                      MAX_NEAREST_NEIGHBOR entries.
//   matched_dist_sqr : receives the extent of the lookup: the distance of the farthest of the k nearest
//                      neighbours, or opts.search_radius when a radius search was used.
//   type             : true for the radiance map, false for the importance map (logging only).
//
// Strategy:
//   opts.type == 1 : k-nearest-neighbour search; if even the k-th neighbour is closer than
//                    opts.search_radius, gather everything inside that radius instead.
//   opts.type == 2 : radius search; if nothing is found, fall back to a k-nearest-neighbour search.
//
// Returns the number of indices written to matched_indices (at most MAX_NEAREST_NEIGHBOR). If the tree
// yields no neighbour at all, 0 is returned and matched_dist_sqr is set to opts.search_radius.
int IntegratorAD_PathSpace::queryPhotonMap(const KDtree<Float> &indices, const GuidingOptions& opts, const Float* query_point,
                                           size_t* matched_indices, Float& matched_dist_sqr, bool type) const {
    assert( opts.type == 1 || opts.type == 2 );
    int num_matched = 0;
    Float dist_sqr[NUM_NEAREST_NEIGHBOR];
    if (opts.type == 1) {
        num_matched = indices.knnSearch(query_point, NUM_NEAREST_NEIGHBOR, matched_indices, dist_sqr);
        assert(num_matched > 0);
        if (num_matched <= 0) {
            // Release builds: never read dist_sqr[-1]; report an empty lookup with a well-defined extent.
            matched_dist_sqr = opts.search_radius;
            return 0;
        }
        matched_dist_sqr = dist_sqr[num_matched - 1];
        if (matched_dist_sqr <  opts.search_radius) {
#ifdef VERBOSE
            omp_set_lock(&messageLock);
            if (type)
                std::cout << "[INFO] RadianceMap: " << "r = " << matched_dist_sqr << " < " << opts.search_radius << std::endl;
            else
                std::cout << "[INFO] ImportanceMap: " << "r = " << matched_dist_sqr << " < " << opts.search_radius << std::endl;
            omp_unset_lock(&messageLock);
#endif
            std::vector<std::pair<size_t, Float>> rsearch_result;
            nanoflann::SearchParams search_params;
            // The second argument of radiusSearch() is the search radius, not a neighbour count: it has to
            // be opts.search_radius so that it matches the extent reported through matched_dist_sqr.
            num_matched = indices.radiusSearch(query_point, opts.search_radius, rsearch_result, search_params);
            matched_dist_sqr = opts.search_radius;
            if (num_matched > MAX_NEAREST_NEIGHBOR) {
                omp_set_lock(&messageLock);
                std::cout << "[INFO] #matched = " << num_matched << " > " << MAX_NEAREST_NEIGHBOR << (type ? "(RadianceMap)" : "(ImportanceMap)") << std::endl;
                omp_unset_lock(&messageLock);
                // Keep only as many matches as the caller's buffer can hold.
                num_matched = MAX_NEAREST_NEIGHBOR;
            }

            for (int i = 0; i < num_matched; i++)
                matched_indices[i] = rsearch_result[i].first;
        }
    } else {
        nanoflann::SearchParams search_params;
        std::vector<std::pair<size_t, Float>> rsearch_result;
        num_matched = indices.radiusSearch(query_point, opts.search_radius, rsearch_result, search_params);
        if (num_matched == 0) {
#ifdef VERBOSE
            omp_set_lock(&messageLock);
            if (type)
                std::cout << "[INFO] RadianceMap: No photon is found within dist(d2) " << opts.search_radius << std::endl;
            else
                std::cout << "[INFO] ImportanceMap: No photon is found within dist(d2) " << opts.search_radius << std::endl;
            omp_unset_lock(&messageLock);
#endif
            // Nothing inside the radius: fall back to the k nearest neighbours.
            num_matched = indices.knnSearch(query_point, NUM_NEAREST_NEIGHBOR, matched_indices, dist_sqr);
            assert(num_matched > 0);
            if (num_matched <= 0) {
                // Release builds: never read dist_sqr[-1]; report an empty lookup with a well-defined extent.
                matched_dist_sqr = opts.search_radius;
                return 0;
            }
            matched_dist_sqr = dist_sqr[num_matched - 1];
        } else {
            if (num_matched > MAX_NEAREST_NEIGHBOR) {
                omp_set_lock(&messageLock);
                std::cout << "[INFO] #matched = " << num_matched << " > " << MAX_NEAREST_NEIGHBOR << std::endl;
                omp_unset_lock(&messageLock);
                // Keep only as many matches as the caller's buffer can hold.
                num_matched = MAX_NEAREST_NEIGHBOR;
            }

            matched_dist_sqr = opts.search_radius;
            for (int i = 0; i < num_matched; i++)
                matched_indices[i] = rsearch_result[i].first;
        }
    }
    return num_matched;
}

void IntegratorAD_PathSpace::preprocessIndirect(const Scene &scene, const GuidingOptions& opts, int max_bounces,
                                                const std::vector<MapNode> &rad_nodes, const KDtree<Float> &rad_indices,
                                                const std::vector<MapNode> &imp_nodes, const KDtree<Float> &imp_indices,
                                                ptr<float> data, bool quiet) const
{
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
    const std::vector<int> &params = opts.params;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( int omp_i = 0; omp_i < params[0]*params[1]; ++omp_i ) {
        size_t matched_indices[MAX_NEAREST_NEIGHBOR];
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];

        const int i = omp_i/params[1], j = omp_i % params[1];
        for ( int k = 0; k < params[2]; ++k ) {
            Float res = 0.0f;
            for ( int t = 0; t < params[3]; ++t ) {
                Vector rnd = sampler.next3D();
                rnd[0] = (rnd[0] + i)/static_cast<Float>(params[0]);
                rnd[1] = (rnd[1] + j)/static_cast<Float>(params[1]);
                rnd[2] = (rnd[2] + k)/static_cast<Float>(params[2]);

                int shape_id;
                RayAD edgeRay;
                Float edgePdf, value1 = 0.0f;
                const Edge &rEdge = scene.sampleEdgeRay(rnd, shape_id, edgeRay, edgePdf);
                if ( shape_id >= 0 ) {
                    EdgeEvaluationRecord eRec;
                    evalEdge(scene, shape_id, rEdge, edgeRay, &sampler, eRec);
                    value1 = eRec.value1.der.abs().maxCoeff()/edgePdf;
                    if (value1 > 0.0f) {
                        Float pt_rad[3] = {eRec.its1.p[0], eRec.its1.p[1], eRec.its1.p[2]};
                        Float pt_imp[3] = {eRec.its2.p[0], eRec.its2.p[1], eRec.its2.p[2]};
                        Float matched_r2_rad, matched_r2_imp;
                        int num_matched = queryPhotonMap(rad_indices, opts, pt_rad, matched_indices, matched_r2_rad, true);
                        std::vector<Spectrum> radiance(max_bounces, Spectrum::Zero());
                        for (int m = 0; m < num_matched; m++) {
                            const MapNode& node = rad_nodes[matched_indices[m]];
                            assert(node.depth < max_bounces);
                            Float bsdf_val = node.its.evalBSDF(node.its.toLocal(edgeRay.dir.val)).maxCoeff();
                            if (bsdf_val < BSDF_CLAMPING_MIN) bsdf_val = BSDF_CLAMPING_MIN;
                            if (bsdf_val > BSDF_CLAMPING_MAX) bsdf_val = BSDF_CLAMPING_MAX;
                            radiance[node.depth] += node.val * bsdf_val;
                        }
                        num_matched = queryPhotonMap(imp_indices, opts, pt_imp, matched_indices, matched_r2_imp, false);
                        std::vector<Spectrum> importance(max_bounces, Spectrum::Zero());
                        for (int m = 0; m < num_matched; m++) {
                            const MapNode& node = imp_nodes[matched_indices[m]];
                            assert(node.depth < max_bounces);
#ifdef USE_BOUNDARY_NEE
                            assert(node.depth > 0);
#endif
                            // Float bsdf_val = node.its.evalBSDF(node.its.toLocal(-edgeRay.dir.val), EBSDFMode::EImportanceWithCorrection).maxCoeff();
                            // if (bsdf_val < BSDF_CLAMPING_MIN) bsdf_val = BSDF_CLAMPING_MIN;
                            // if (bsdf_val > BSDF_CLAMPING_MAX) bsdf_val = BSDF_CLAMPING_MAX;
                            if (node.depth == 0) {
                                importance[node.depth] += node.val.maxCoeff() * node.its.ptr_emitter->evalDirection(node.its.geoFrame.n, -edgeRay.dir.val);
                            }
                            else {
                                Float bsdf_val = node.its.evalBSDF(node.its.toLocal(-edgeRay.dir.val), EBSDFMode::EImportanceWithCorrection).maxCoeff();
                                if (bsdf_val < BSDF_CLAMPING_MIN) bsdf_val = BSDF_CLAMPING_MIN;
                                if (bsdf_val > BSDF_CLAMPING_MAX) bsdf_val = BSDF_CLAMPING_MAX;
                                importance[node.depth] += node.val * bsdf_val;
                            }
                        }

                        Spectrum value2 = Spectrum::Zero();
#ifdef USE_BOUNDARY_NEE
                        int impStart = 1;
#else
                        int impStart = 0;
#endif
                        for (int m = 0; m < max_bounces; m++) {
                            for (int n = impStart; n < max_bounces-m; n++)
                                value2 += radiance[m] * importance[n];
                        }

                        value1 *= value2.maxCoeff() / (matched_r2_rad * matched_r2_imp);
                        assert(std::isfinite(value1));
                    }
                }
                res += value1;
            }

            Float avg = res/static_cast<Float>(params[3]);
            data[static_cast<long long>(omp_i)*params[2] + k] = static_cast<float>(avg);
        }

        if ( !quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(static_cast<Float>(omp_i)/(params[0]*params[1]));
            omp_unset_lock(&messageLock);
        }
    }
}

void IntegratorAD_PathSpace::preprocess(const Scene &scene, int max_bounces, const GuidingOptions& opts, ptr<float> data) const {
    using namespace std::chrono;
    const std::vector<int> &params = opts.params;

    assert(opts.type < 4);
    std::string guiding_type[4] = {"Direct Guiding", "Indirect Guiding (KNN)",
                                   "Indirect Guiding (Radius Search)", "Indirect Guiding (Old)"};

    assert(params.size() == 4);
    if ( !opts.quiet )
        std::cout << "[INFO] Preprocessing for " << guiding_type[opts.type]
                  << " at (" << params[0] << " x " << params[1] << " x " << params[2] << ") ... " << std::endl;

    auto _start = high_resolution_clock::now();
    if ( opts.type == 0) {
        preprocessDirect(scene, params, max_bounces, data, opts.quiet);
    } else {
#ifdef USE_BOUNDARY_NEE
        if (max_bounces < 2) {
            if ( !opts.quiet )
                std::cout << "[INFO] max_bounces < 2, no indirect component. Guiding cancelled." << std::endl;
            return;
        }
#endif
        if (opts.type == 3) {
            // Old Indirect guiding (unbiased)
            const int nworker = omp_get_num_procs();
            std::vector<RndSampler> samplers;
            for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
            for ( int omp_i = 0; omp_i < params[0]*params[1]; ++omp_i ) {
                const int tid = omp_get_thread_num();
                RndSampler &sampler = samplers[tid];
                const int i = omp_i/params[1], j = omp_i % params[1];
                for ( int k = 0; k < params[2]; ++k ) {
                    Float res = 0.0f;
                    for ( int t = 0; t < params[3]; ++t ) {
                        Vector rnd = sampler.next3D();
                        rnd[0] = (rnd[0] + i)/static_cast<Float>(params[0]);
                        rnd[1] = (rnd[1] + j)/static_cast<Float>(params[1]);
                        rnd[2] = (rnd[2] + k)/static_cast<Float>(params[2]);

                        int shape_id;
                        RayAD edgeRay;
                        Float edgePdf, value = 0.0f;
                        const Edge &rEdge = scene.sampleEdgeRay(rnd, shape_id, edgeRay, edgePdf);
                        if ( shape_id >= 0 ) {
                            EdgeEvaluationRecord eRec;
                            evalEdge(scene, shape_id, rEdge, edgeRay, &sampler, eRec);
                            value = eRec.value1.der.abs().maxCoeff()/edgePdf;
                        }
                        res += value;
                    }

                    Float avg = res/static_cast<Float>(params[3]);
                    data[static_cast<long long>(omp_i)*params[2] + k] = static_cast<float>(avg);
                }

                if ( !opts.quiet ) {
                    omp_set_lock(&messageLock);
                    progressIndicator(static_cast<Float>(omp_i)/(params[0]*params[1]));
                    omp_unset_lock(&messageLock);
                }
            }
        } else {
            if ( !opts.quiet )
                std::cout << "[INFO] #camPath = " << opts.num_cam_path << ", #lightPath = " << opts.num_light_path << std::endl;
            std::vector<MapNode> rad_nodes, imp_nodes;
            buildPhotonMap(scene, opts, max_bounces-1, rad_nodes, imp_nodes);
            if ( !opts.quiet )
                std::cout << "[INFO] #rad_nodes = " << rad_nodes.size() << ", #imp_nodes = " << imp_nodes.size() << std::endl;

            // Build up the point KDtree for query
            PointCloud<Float> rad_cloud;
            rad_cloud.pts.resize(rad_nodes.size());
            for (size_t i = 0; i < rad_nodes.size(); i++) {
                rad_cloud.pts[i].x = rad_nodes[i].its.p[0];
                rad_cloud.pts[i].y = rad_nodes[i].its.p[1];
                rad_cloud.pts[i].z = rad_nodes[i].its.p[2];
            }
            KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
            rad_indices.buildIndex();

            PointCloud<Float> imp_cloud;
            imp_cloud.pts.resize(imp_nodes.size());
            for (size_t i = 0; i < imp_nodes.size(); i++) {
                imp_cloud.pts[i].x = imp_nodes[i].its.p[0];
                imp_cloud.pts[i].y = imp_nodes[i].its.p[1];
                imp_cloud.pts[i].z = imp_nodes[i].its.p[2];
            }
            KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
            imp_indices.buildIndex();
            // Indirect Guiding
            preprocessIndirect(scene, opts, max_bounces, rad_nodes, rad_indices, imp_nodes, imp_indices, data, opts.quiet);
        }
    }
    if ( !opts.quiet )
        std::cout << "\nDone in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
}

void IntegratorAD_PathSpace::renderEdgesDirect(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), renderEdgesDirect() should not be used." << std::endl;
    assert(false);
#endif
    const Camera &camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_direct)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
    num_samples = num_block*num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( long long index_block = 0; index_block < num_block; ++index_block ) {
        for ( int omp_i = 0; omp_i < num_samples_per_block; ++omp_i ) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, m_taskId[tid] = index_block*num_samples_per_block + omp_i);
            int shape_id;
            RayAD edgeRay;
            Float edgePdf;
            const Edge &rEdge = scene.sampleEdgeRayDirect(sampler.next3D(), shape_id, edgeRay, edgePdf);
            if ( shape_id >= 0 ) {
                std::pair<int, Spectrum> importance[BDPT_MAX_PATH_LENGTH + 1];
                EdgeEvaluationRecord eRec;
                int num_indirect_path = evalEdgeDirect(scene, shape_id, rEdge, edgeRay, &sampler, options.max_bounces, eRec, importance);

                // one-bounce
                if (!eRec.sensor_vals.isZero()) {
                    for (int i = 0; i < 4; ++i) {
                        if (eRec.sensor_vals[i] > Epsilon) {
                            for (int j = 0; j < nder; ++j) {
                                image_per_thread[tid][j*num_pixels + eRec.pixel_indices[i]] += eRec.value0.grad(j) * eRec.sensor_vals[i] / edgePdf;
                            }
                        }
                    }
                }

                // multi-bounce
                if ( num_indirect_path > 0 ) {
                    Spectrum light_val = eRec.its2.Le(-edgeRay.dir.val);
                    for ( int j = 0; j < nder; ++j ) {
                        for (int k = 0; k < num_indirect_path; k++)
                            image_per_thread[tid][j*num_pixels + importance[k].first] +=
                                eRec.value1.grad(j)*importance[k].second*light_val/edgePdf;
                    }
                }
            }
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block + 1)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !options.quiet ) std::cout << std::endl;

    for ( int i = 0; i < nworker; ++i )
        for ( int j = 0; j < nder; ++j )
            for ( int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel ) {
                int offset1 = ((j + 1)*num_pixels + idx_pixel)*3,
                    offset2 = j*num_pixels + idx_pixel;
                rendered_image[offset1    ] += image_per_thread[i][offset2][0]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 1] += image_per_thread[i][offset2][1]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 2] += image_per_thread[i][offset2][2]/static_cast<Float>(num_samples);
            }
}

void IntegratorAD_PathSpace::radiance(const Scene& scene, RndSampler* sampler, const Intersection &its, int max_bounces, Spectrum *ret) const {
    Intersection _its(its);
    if ( _its.isEmitter() ) ret[0] += _its.Le(_its.toWorld(_its.wi));

    Spectrum throughput(1.0f);
    for ( int d_emitter = 1; d_emitter <= max_bounces; d_emitter++ ) {
        // Direct illumination
        Float pdf_nee;
        Vector wo;
        Spectrum value = scene.sampleEmitterDirect(_its, sampler->next2D(), sampler, wo, pdf_nee);
        if ( !value.isZero(Epsilon) ) {
            Spectrum bsdf_val = _its.evalBSDF(wo);
            Float bsdf_pdf = _its.pdfBSDF(wo);
            Float mis_weight = pdf_nee/(pdf_nee + bsdf_pdf);
            ret[d_emitter] += throughput*value*bsdf_val*mis_weight;
        }

        // Indirect illumination
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = _its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if ( bsdf_weight.isZero(Epsilon) ) break;
        wo = _its.toWorld(wo);

        Ray ray_emitter(_its.p, wo);
        if ( !scene.rayIntersect(ray_emitter, true, _its) ) break;
        throughput *= bsdf_weight;
        if ( _its.isEmitter() ) {
            Spectrum light_contrib = _its.Le(-ray_emitter.dir);
            if ( !light_contrib.isZero(Epsilon) ) {
                Float dist_sq = (_its.p - ray_emitter.org).squaredNorm();
                Float G = _its.geoFrame.n.dot(-ray_emitter.dir)/dist_sq;
                pdf_nee = scene.pdfEmitterSample(_its)/G;
                Float mis_weight = bsdf_pdf/(pdf_nee + bsdf_pdf);
                ret[d_emitter] += throughput*light_contrib*mis_weight;
            }
        }
    }
}

int IntegratorAD_PathSpace::weightedImportance(const Scene& scene, RndSampler* sampler, const Intersection& its0, int max_depth, const Spectrum *weight,
                                               std::pair<int, Spectrum>* ret) const {
    Intersection its = its0;
    Spectrum throughput(1.0f);
    Vector d0;
    Matrix2x4 pix_uvs;
    Ray ray_sensor;
    int num_valid_path = 0;
    for (int d_sensor = 1; d_sensor <= max_depth; d_sensor++) {
        // sample a new direction
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, EBSDFMode::EImportanceWithCorrection);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo_local);
        Vector wi = its.toWorld(its.wi);
        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
            break;
        throughput *= bsdf_weight;
        ray_sensor = Ray(its.p, wo);
        scene.rayIntersect(ray_sensor, true, its);
        if ( !its.isValid() )
            break;

        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(its, sampler, 0, pix_uvs, d0);
        if (!sensor_vals.isZero()) {
            Vector wi = -ray_sensor.dir;
            Vector wo = d0, wo_local = its.toLocal(wo);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() > 0 && woDotGeoN * wo_local.z() > 0) {
                assert(num_valid_path <= BDPT_MAX_PATH_LENGTH * 4);
                
                for (int i = 0; i < 4; i++) {
                    ret[num_valid_path].second = (weight != nullptr) ? weight[max_depth - d_sensor] : Spectrum(1.0f);
                    ret[num_valid_path].second *= throughput * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection) * sensor_vals[i];
                    ret[num_valid_path].first = scene.camera.getPixelIndex(pix_uvs.col(i));
                    num_valid_path++;
                }
            }
        }
    }
    return num_valid_path;
}

/**
 * Completes a sampled boundary segment (eRec.its1 -- eRec.its2) into full light
 * paths and adds their derivative contributions to `image`.
 *
 * The emitter side is traced from eRec.its2 with radiance(); the sensor side is
 * either a direct connection of eRec.its1 to the sensor (eRec.sensor_vals /
 * eRec.pixel_indices) or a traced sub-path obtained from weightedImportance().
 *
 * @param eRec        edge evaluation record providing its1, its2, value0, value1,
 *                    sensor_vals and pixel_indices.
 * @param edgePdf     pdf of the sampled edge ray; every contribution is divided by it.
 * @param max_bounces total bounce budget, must be > 0.
 * @param image       accumulation buffer laid out as nder blocks of num_pixels spectra.
 */
void IntegratorAD_PathSpace::traceRayFromEdgeSegement(const Scene &scene, const EdgeEvaluationRecord& eRec, Float edgePdf, int max_bounces, RndSampler *sampler, std::vector<Spectrum> &image) const {
    assert(max_bounces > 0);
    const int num_pixels = image.size()/nder;

    /*** Trace ray towards emitter from its2 ***/
    // After the prefix sum, L[d] holds the radiance from emitter sub-paths of at most d bounces.
    std::vector<Spectrum> L(max_bounces, Spectrum(0.0f));
    {
        radiance(scene, sampler, eRec.its2, max_bounces - 1, &L[0]);
#ifdef USE_BOUNDARY_NEE
        // Emission directly at its2 is dropped in NEE builds
        // (presumably accounted for by renderEdgesDirect() - confirm).
        L[0] = Spectrum(0.0f);
#endif
        for (int d_emitter = 1; d_emitter < max_bounces; d_emitter++)
            L[d_emitter] += L[d_emitter - 1];
    }

    /*** Trace ray towards sensor from its1 ***/
    if (!eRec.sensor_vals.isZero()) {
        for (int i = 0; i < 4; i++) {
            if (eRec.sensor_vals[i] > Epsilon) {
                for (int j = 0; j < nder; ++j) {
                    // Direct connect to sensor: the whole bounce budget is available on the emitter side.
                    image[j*num_pixels + eRec.pixel_indices[i]] += L[max_bounces - 1] * Spectrum(eRec.value0.grad(j)) * eRec.sensor_vals[i] / edgePdf;
                }
            }
        }
    }

    if ( max_bounces > 1 ) {
        // weightedImportance() writes up to four records per bounce, so a buffer of
        // BDPT_MAX_PATH_LENGTH + 1 entries could be overrun. Size it for the worst case.
        std::vector<std::pair<int, Spectrum> > pathThroughput(4*static_cast<size_t>(max_bounces - 1));
        const int num_path = weightedImportance(scene, sampler, eRec.its1, max_bounces - 1, &L[0], pathThroughput.data());
        for (int i = 0; i < num_path; i++) {
            for (int j = 0; j < nder; j++) {
                image[j*num_pixels + pathThroughput[i].first] += pathThroughput[i].second * eRec.value1.grad(j) / edgePdf;
            }
        }
    }
}

void IntegratorAD_PathSpace::renderEdges(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifdef USE_BOUNDARY_NEE
    if ( options.max_bounces > 1 ) {
#else
    if ( options.max_bounces >= 1 ) {
#endif
        const Camera &camera = scene.camera;
        int num_pixels = camera.getNumPixels();
        const int nworker = omp_get_num_procs();
        std::vector<std::vector<Spectrum> > image_per_thread(nworker);
        for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f));

        constexpr int num_samples_per_block = 128;
        long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_indirect)*num_pixels;
        const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
        num_samples = num_block*num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (long long index_block = 0; index_block < num_block; ++index_block) {
            for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
                const int tid = omp_get_thread_num();
                RndSampler sampler(options.seed, m_taskId[tid] = index_block*num_samples_per_block + omp_i);
                // indirect contribution of edge term
                int shape_id;
                RayAD edgeRay;
                Float edgePdf;
                const Edge &rEdge = scene.sampleEdgeRay(sampler.next3D(), shape_id, edgeRay, edgePdf);
                if (shape_id >= 0) {
                    EdgeEvaluationRecord eRec;
                    evalEdge(scene, shape_id, rEdge, edgeRay, &sampler, eRec);
                    if (eRec.its1.isValid() && eRec.its2.isValid()) {
                        bool zeroVelocity = eRec.value0.der.isZero(Epsilon) && eRec.value1.der.isZero(Epsilon);
                        if (!zeroVelocity) {
                            traceRayFromEdgeSegement(scene, eRec, edgePdf, options.max_bounces, &sampler, image_per_thread[tid]);
                        }
                    }
                }
            }

            if ( !options.quiet ) {
                omp_set_lock(&messageLock);
                progressIndicator(Float(index_block + 1)/num_block);
                omp_unset_lock(&messageLock);
            }
        }
        if ( !options.quiet )
                std::cout << std::endl;

        for ( int i = 0; i < nworker; ++i )
            for ( int j = 0; j < nder; ++j )
                for ( int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel ) {
                    int offset1 = ((j + 1)*num_pixels + idx_pixel)*3,
                        offset2 = j*num_pixels + idx_pixel;
                    rendered_image[offset1    ] += image_per_thread[i][offset2][0]/static_cast<Float>(num_samples);
                    rendered_image[offset1 + 1] += image_per_thread[i][offset2][1]/static_cast<Float>(num_samples);
                    rendered_image[offset1 + 2] += image_per_thread[i][offset2][2]/static_cast<Float>(num_samples);
                }
    }
}


/**
 * Entry point of the integrator: optionally reports options.grad_threshold and
 * then forwards to IntegratorAD::render().
 *
 * Note that the stream manipulators are sticky: after a non-quiet call std::cout
 * stays in scientific notation with a precision of 1.
 */
void IntegratorAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    if ( !options.quiet ) {
        std::cout << std::scientific << std::setprecision(1);
        std::cout << "[INFO] grad_threshold = " << options.grad_threshold << std::endl;
    }
    IntegratorAD::render(scene, options, rendered_image);
}
