#include "integratorADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "math_func.h"
#include "bidir_utils.h"
#include "nanoflann.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <fstream>

#define NUM_NEAREST_NEIGHBORS 5

template <typename T>
struct PointCloud {
    /// One 3-D position stored in the cloud.
    struct Point
    {
        T x, y, z;
    };

    /// Backing storage that the nanoflann adaptor methods below read from.
    std::vector<Point> pts;

    /// Number of points exposed to the k-d tree.
    inline size_t kdtree_get_point_count() const { return pts.size(); }

    /// Coordinate `dim` of point `idx`: 0 -> x, 1 -> y, anything else -> z.
    inline T kdtree_get_pt(const size_t idx, const size_t dim) const
    {
        const Point &p = pts[idx];
        switch (dim) {
            case 0:  return p.x;
            case 1:  return p.y;
            default: return p.z;
        }
    }

    /// Always reports "no precomputed bounding box" (returns false), so nanoflann
    /// computes the bounds itself.
    template <class BBOX>
    bool kdtree_get_bbox(BBOX& /* bb */) const { return false; }
};

// 3-D nanoflann k-d tree indexing a PointCloud<T>. It uses nanoflann's
// L2_Simple_Adaptor metric, so distances returned by searches (e.g. knnSearch)
// are squared Euclidean distances.
template <typename T>
using KDtree = nanoflann::KDTreeSingleIndexAdaptor<nanoflann::L2_Simple_Adaptor<T, PointCloud<T> >, PointCloud<T>, 3>;

// Differentiable scattering function at this event for the direction pair (wi, wo).
// Medium events evaluate the medium's phase function. Surface events evaluate the BSDF
// with wo converted to the local frame (wi is not used in that case).
SpectrumAD EventRecordAD::fAD(const Scene &scene, const VectorAD &wi, const VectorAD &wo) const {
    if (!onSurface) {
        const PhaseFunction *phase = scene.phase_list[ptr_med->phase_id];
        return phase->evalAD(wi, wo);
    }
    const VectorAD local_wo = its.toLocal(wo);
    return its.evalBSDF(local_wo);
}

bool EventRecordAD::sample(const Scene &scene, RndSampler *sampler, const Vector &wi, Vector &wo, Float &pdf) const {
    if (onSurface)
    {
        VectorAD wo_local;
        Float bsdf_eta;
        auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, pdf, bsdf_eta);
        wo = its.toIntersection().toWorld(wo_local.val);
        if (bsdf_weight.isZero(Epsilon))
            return false;
    }
    else
    {
        const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
        auto phase = ptr_phase->sample(wi, sampler->next2D(), wo);
        pdf = ptr_phase->pdf(wi, wo);
        if (phase < Epsilon)
            return false;
    }
    return true;
}

// Scattering value divided by its sampling pdf for the stored incoming direction `wi`
// and the given outgoing direction `wo`.
// Medium events use the phase function. Surface events use the BSDF in the local frame;
// the intersection must be valid with a non-null BSDF (checked by assert only).
// No guard against a zero pdf is performed here.
SpectrumAD EventRecordAD::f(const Scene &scene, const VectorAD &wo) const {
    SpectrumAD value;
    if (!onSurface) {
        const PhaseFunction *phase = scene.phase_list[ptr_med->phase_id];
        const Float phase_pdf = phase->pdf(wi.val, wo.val);
        value = SpectrumAD(phase->evalAD(wi, wo)) / phase_pdf;
        return value;
    }

    assert(its.isValid());
    assert(!its.getBSDF()->isNull());
    const VectorAD local_wo = its.toLocal(wo);
    const Float bsdf_pdf = its.getBSDF()->pdf(its, local_wo.val);
    value = its.evalBSDF(local_wo) / bsdf_pdf;
    return value;
}

// Initializes the OpenMP lock (messageLock) that serializes progress output from
// worker threads.
IntegratorAD_PathSpace::IntegratorAD_PathSpace() {
    omp_init_lock(&this->messageLock);
}

// Releases the OpenMP lock created in the constructor.
IntegratorAD_PathSpace::~IntegratorAD_PathSpace() {
    omp_destroy_lock(&this->messageLock);
}

/******   Helper Functions    ******/
bool getBoundaryEndpointOnEmitter(const Scene& scene, RndSampler* sampler, const Ray& bRay, const Medium* ptr_med, int max_depth,
                                  int& depth_boundary, Intersection& its, Spectrum& attenuated_Le)
{
    Ray _ray(bRay);
    while ( depth_boundary < max_depth - 1) {
        if ( !scene.rayIntersect(_ray, true, its) ) break;
        depth_boundary++;
        if ( its.ptr_bsdf->isNull() ) {
            _ray.org = its.p;
        }
        else {
            if( !its.isEmitter() ) break;
            if ( its.geoFrame.n.dot(-_ray.dir) < Epsilon ) break;
            Float dist = (its.p - bRay.org).norm();
            Float transmittance = scene.evalTransmittance(bRay, true, ptr_med, dist, sampler, max_depth-1);
            if ( transmittance < Epsilon ) {
                // printf("[WARN] Get zero transmittance: ray.org = (%f, %f, %f), p = (%f, %f, %f)\n",
                //      bRay.org.x(), bRay.org.y(), bRay.org.z(),
                //      its.p.x(), its.p.y(), its.p.z());
                return false;
            }
            attenuated_Le = its.Le(-_ray.dir) * transmittance;
            return true;
        }
    }
    return false;
}

bool traceRayForBoundaryEndpoint(const Scene& scene, RndSampler* sampler, Ray bRay, int max_depth,
                                 int& depth_boundary, EventRecord& bEvt, Spectrum& throughput)
{
    bool success = false;
    while ( depth_boundary < max_depth ) {
        if ( !scene.rayIntersect(bRay, true, bEvt.its) ) break;
        depth_boundary++;
        bool inside_med = bEvt.ptr_med != nullptr &&
                          bEvt.ptr_med->sampleDistance(bRay, bEvt.its.t, sampler->next2D(), sampler, bRay.org, throughput);
        if ( !inside_med ) {
            if ( bEvt.its.ptr_bsdf->isNull() ) {
                bRay.org = bEvt.its.p;
                if ( bEvt.its.isMediumTransition() )
                    bEvt.ptr_med = bEvt.its.getTargetMedium(bRay.dir);
            }
            else {
                bEvt.onSurface = true;
                Float gnDotD = bEvt.its.geoFrame.n.dot(-bRay.dir);
                Float snDotD = bEvt.its.shFrame.n.dot(-bRay.dir);
                success = ( bEvt.its.ptr_bsdf->isTransmissive() && math::signum(gnDotD)*math::signum(snDotD) > 0.5f) ||
                          (!bEvt.its.ptr_bsdf->isTransmissive() && gnDotD > Epsilon && snDotD > Epsilon);
                break;
            }
        } else {
            bEvt.onSurface = false;
            bEvt.scatter = bRay.org;
            bEvt.wi = -bRay.dir;
            success = true;
            break;
        }
    }
    return success;
}

// Recomputes, with derivatives, the endpoint u2 of the boundary segment that starts at x1
// and passes through the edge point pEdge.
//  * Surface endpoint: intersect the differentiable ray (x1, pEdge - x1) with the triangle
//    recorded in bEvt.its (shape bEvt.its.indices[0], face bEvt.its.indices[1]).
//  * Medium endpoint: place x2 along the normalized ray from pEdge at the distance of
//    bEvt.scatter from pEdge. If x2 drifts from bEvt.scatter by more than ShadowEpsilon,
//    print a warning and fail. Otherwise delegate to the medium's
//    tet_ptr->query_boundary(x2, u2).
// Returns false when the endpoint cannot be computed.
bool getBoundaryEndpointFromBoundaryRay(const Scene& scene, const EventRecord& bEvt, const VectorAD& x1, const VectorAD& pEdge, VectorAD& u2) {
    if ( bEvt.onSurface ) {
        RayAD bRay(x1, pEdge-x1);
        const Shape &shape = *scene.shape_list[bEvt.its.indices[0]];
        const Vector3i &f1 = shape.getIndices(bEvt.its.indices[1]);
        const VectorAD &v0 = shape.getVertexAD(f1[0]), &v1 = shape.getVertexAD(f1[1]), &v2 = shape.getVertexAD(f1[2]);
        return rayIntersectTriangleAD(v0, v1, v2, bRay, u2);
    }

    // Medium endpoint. The original code also intersected this ray with the triangle stored
    // in bEvt.its here, but never used the result; that dead work has been removed.
    RayAD bRay(pEdge, (pEdge-x1).normalized());
    Float t = (bEvt.scatter - pEdge.val).norm();
    VectorAD x2 = bRay.org + t * bRay.dir;

    if ( (x2.val-bEvt.scatter).norm() > ShadowEpsilon) {
        printf("[WARN] x1 = (%.4f, %.4f, %.4f), scatter = (%.4f, %.4f, %.4f), x2 = (%.4f, %.4f, %.4f), \n",
                x1.val.x(), x1.val.y(), x1.val.z(),
                bEvt.scatter.x(), bEvt.scatter.y(), bEvt.scatter.z(),
                x2.val.x(), x2.val.y(), x2.val.z());
        return false;
    }

    return bEvt.ptr_med->tet_ptr->query_boundary(x2, u2);
}

// Traces a path from the event `evt` by sampling BSDFs / phase functions, and at every
// vertex connects to the camera with scene.sampleAttenuatedSensorDirect.
// For each of the (up to 4) returned pixel samples whose sensor value exceeds Epsilon,
// one (pixel index, contribution) pair is appended to `ret`.
//   evt:              starting vertex; a medium scatter point when !evt.onSurface,
//                     otherwise a surface intersection.
//   max_interactions: loop bound on `depth`; the remaining budget
//                     (max_interactions - depth) is passed to the sensor connection.
//   weight:           optional per-vertex multipliers indexed by max_interactions - depth;
//                     nullptr means a weight of 1.
//   ret:              output array. NOTE(review): presumably sized BDPT_MAX_PATH_LENGTH;
//                     only an assert guards overflow — confirm with callers.
// Returns the number of entries written to `ret`.
int weightedImportance(const Scene& scene, RndSampler* sampler, const EventRecord& evt, int max_interactions, const Spectrum *weight, std::pair<int, Spectrum>* ret ) {
    int num_valid_path = 0;
    int depth = 0;

    Vector d0;
    Matrix2x4 pixel_uvs;
    Ray ray_sensor;   // NOTE(review): unused.
    Spectrum throughput(1.0f);

    // Surface events start with no current medium; medium events use evt.ptr_med.
    const Medium* ptr_med = evt.onSurface ? nullptr : evt.ptr_med;
    Intersection its = evt.its;
    bool inside_med = ptr_med != nullptr;
    // The ray starts at the event point. Its direction is the incoming propagation
    // direction (-wi in world space) so that -ray.dir is the "wi" used below.
    Ray ray( inside_med ? evt.scatter : evt.its.p,
             inside_med ? -evt.wi : -evt.its.toWorld(its.wi) );

    while (depth <= max_interactions) {
        // Connect the current vertex to the sensor; surfaces with a null BSDF get no connection.
        Array4 sensor_vals = inside_med ?  scene.sampleAttenuatedSensorDirect(ray.org, ptr_med, sampler, max_interactions-depth, pixel_uvs, d0)
                                        : (its.ptr_bsdf->isNull() ? Array4(0.0) : scene.sampleAttenuatedSensorDirect(its, sampler, max_interactions-depth, pixel_uvs, d0));
        if ( !sensor_vals.isZero() ) {
            Spectrum value0(0.0);
            if ( !inside_med ) {
                Vector wi = -ray.dir;
                Vector wo = d0, wo_local = its.toLocal(wo);
                Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                // Only connect when geometric and shading normals agree in sign for both directions.
                if (wiDotGeoN * its.wi.z() > Epsilon && woDotGeoN * wo_local.z() > Epsilon) {
                    value0 = (weight == nullptr) ? Spectrum(1.0) : weight[max_interactions - depth];
                    value0 *= throughput * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection);
                }
            } else {
                const PhaseFunction* ptr_phase = scene.phase_list[ptr_med->phase_id];
                value0 = (weight == nullptr) ? Spectrum(1.0) : weight[max_interactions - depth];
                value0 *= throughput * ptr_phase->eval(-ray.dir, d0);
            }

            // Emit one record per pixel sample with a non-negligible sensor response.
            for (int i = 0; i < 4; i++) {
                if ( sensor_vals(i) > Epsilon ) {
                    assert(num_valid_path < BDPT_MAX_PATH_LENGTH);
                    ret[num_valid_path].second = value0 * sensor_vals(i);
                    ret[num_valid_path].first = scene.camera.getPixelIndex(pixel_uvs.col(i));
                    num_valid_path++;
                }
            }
        }

        // Extend the path by one vertex.
        if ( inside_med ) {
            Vector wo;
            const PhaseFunction* ptr_phase = scene.phase_list[ptr_med->phase_id];
            Float phase_val = ptr_phase->sample(-ray.dir, sampler->next2D(), wo);
            if ( phase_val < Epsilon ) break;
            throughput *= phase_val;
            ray.dir = wo;
            if ( !scene.rayIntersect(ray, false, its) ) break;
        } else {
            Vector wo_local, wo;
            Float bsdf_pdf, bsdf_eta;
            Spectrum bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, EBSDFMode::EImportanceWithCorrection);
            if (bsdf_weight.isZero()) break;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            // Terminate on inconsistent geometric/shading normal configurations.
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) break;
            throughput *= bsdf_weight;
            ray.org = its.p;
            ray.dir = wo;
            if ( its.isMediumTransition() )
                ptr_med = its.getTargetMedium(wo);
            if (!scene.rayIntersect(ray, true, its) ) break;
        }

        depth++;
        // Sample a free-flight distance in the current medium (if any). sampleDistance
        // receives ray.org and throughput as outputs (presumably the scatter point and
        // the updated throughput).
        inside_med = ptr_med != nullptr && ptr_med->sampleDistance(ray, its.t, sampler->next2D(), sampler, ray.org, throughput);
        if (throughput.isZero()) break;
    }
    return num_valid_path;
}

// Next-event estimation from `evt`: picks an emitter with scene.light_distrb, samples a
// point on it, and returns (emitter value / selection pdf) * transmittance to that point.
//   wo:           direction towards the light; local frame for surface events, world
//                 frame for medium events. Written only on a non-zero result.
//   pdf:          selection pdf times dRec.pdf. A negative value bypasses the
//                 emitter-facing test (presumably a delta emitter — confirm).
//   interactions: forwarded to evalTransmittance.
// Returns zero when the light is back-facing, the shading/geometric side test fails, or
// the shadow ray is fully blocked.
Spectrum sampleAttenuatedEmitterDirect(const Scene& scene, const EventRecord& evt, const Vector2 &_rnd_light, RndSampler* sampler, const Medium* ptr_medium, int max_interactions,
                                       Vector& wo, Float& pdf, int& interactions)
{
    // sampleReuse may rewrite the first random number before it is reused for the emitter sample.
    Vector2 rnd(_rnd_light);
    const int light_id = scene.light_distrb.sampleReuse(rnd[0], pdf);
    const Emitter* emitter = scene.emitter_list[light_id];

    DirectSamplingRecord dRec(evt.x());
    Spectrum value = emitter->sampleDirect(rnd, dRec);
    value /= pdf;
    pdf *= dRec.pdf;

    const Vector origin = evt.onSurface ? evt.its.p : evt.scatter;
    Vector dir = dRec.p - origin;

    bool valid = pdf < 0 || (dir.dot(dRec.n) < -Epsilon && pdf > Epsilon);
    if ( evt.onSurface ) {
        const Float snDotD = dir.dot(evt.its.shFrame.n);
        const Float gnDotD = dir.dot(evt.its.geoFrame.n);
        const bool eitherSide = evt.its.ptr_bsdf->isTransmissive() || evt.its.ptr_bsdf->isTwosided();
        const bool sideOk = eitherSide ? math::signum(gnDotD)*math::signum(snDotD) > 0.5f
                                       : (snDotD > Epsilon && gnDotD > Epsilon);
        valid = valid && sideOk;
    }
    if ( !valid )
        return Spectrum::Zero();

    const Float dist = dir.norm();
    dir = dir/dist;
    Ray shadow_ray(origin, dir);
    Float offset = evt.onSurface ? ShadowEpsilon : 0.0;
    const Float transmittance = scene.evalTransmittance(shadow_ray, offset, ptr_medium, dist, sampler, max_interactions, interactions);
    if ( transmittance == 0 )
        return Spectrum::Zero();

    wo = evt.onSurface ? evt.its.toLocal(dir) : dir;
    return value * transmittance;
}

// Path-traces from the event `evt`, combining next-event estimation
// (sampleAttenuatedEmitterDirect) and BSDF/phase sampling
// (rayIntersectAndLookForEmitter) with a power-heuristic MIS weight.
// Contributions are ADDED (+=) into ret[k], where k = depth + extra interactions
// reported by the callee. Emission at a surface event goes to ret[0].
//   evt:       taken by value and modified locally as the path advances.
//   max_depth: loop bound on depth (starting at 1); ret must hold at least
//              max_depth + 1 entries (asserted only in the medium branch).
void radiance(const Scene& scene, RndSampler* sampler, EventRecord evt, int max_depth, Spectrum *ret) {
    if ( evt.onSurface && evt.its.isEmitter() )
        ret[0] += evt.its.Le(evt.its.toWorld(evt.its.wi));
    Spectrum throughput(1.0f);
    const Medium* ptr_med = evt.onSurface ? nullptr : evt.ptr_med;
    bool inside_med = ptr_med != nullptr;
    // For surface events the initial direction is only a placeholder: the surface branch
    // overwrites `ray` from the sampled BSDF direction before using it.
    Ray ray( inside_med ?  evt.scatter : evt.its.p,
             inside_med ? -evt.wi : Vector(1.0, 0.0, 0.0) );

    int depth = 1;
    while( depth <= max_depth ) {
        if ( inside_med ) {
            // --- Medium vertex: NEE with the phase function, then phase sampling. ---
            Vector wo;
            Float pdf_nee;
            const PhaseFunction* ptr_phase = scene.phase_list[ptr_med->phase_id];
            int num_interactions = 0;
            evt.scatter = ray.org;
            evt.onSurface = false;
            Spectrum value = sampleAttenuatedEmitterDirect(scene, evt, sampler->next2D(), sampler, ptr_med, max_depth-depth, wo, pdf_nee, num_interactions);
            if ( !value.isZero() ) {
                Float phase_val = ptr_phase->eval(-ray.dir, wo);
                if ( phase_val != 0.0) {
                    Float phase_pdf = ptr_phase->pdf(-ray.dir, wo);
                    // A negative NEE pdf gets MIS weight 1 (presumably a delta light that phase sampling cannot hit).
                    Float mis_weight = pdf_nee < 0 ? 1.0 : square(pdf_nee) / (square(pdf_nee) + square(phase_pdf));
                    assert(depth+num_interactions <= max_depth);
                    ret[depth+num_interactions] += throughput * value * mis_weight * phase_val;
                }
            }
            Float phase_val = ptr_phase->sample(-ray.dir, sampler->next2D(), wo);
            if ( phase_val == 0.0 ) break;
            throughput *= phase_val;
            Vector raydir = ray.dir;   // keep the incoming direction for the MIS pdf below
            ray.dir = wo;
            num_interactions = 0;
            Spectrum attenuated_radiance = scene.rayIntersectAndLookForEmitter(ray, false, sampler, ptr_med, max_depth-depth, evt.its, pdf_nee, num_interactions);
            if (!attenuated_radiance.isZero()) {
                Float phase_pdf = ptr_phase->pdf(-raydir, wo);
                Float mis_weight = square(phase_pdf) / (square(pdf_nee) + square(phase_pdf));
                ret[depth+num_interactions] += throughput * attenuated_radiance * mis_weight;
            }
        } else {
            // --- Surface vertex: NEE with the BSDF (skipped for null BSDFs), then BSDF sampling. ---
            Float pdf_nee;
            Vector wo;
            int num_interactions = 0;
            evt.onSurface = true;
            Spectrum value = evt.its.ptr_bsdf->isNull() ? Vector::Zero()
                                                        : sampleAttenuatedEmitterDirect(scene, evt, sampler->next2D(), sampler, ptr_med, max_depth-depth, wo, pdf_nee, num_interactions);
            if ( !value.isZero(Epsilon) ) {
                // wo is in the local frame here (sampleAttenuatedEmitterDirect returns it so for surfaces).
                Spectrum bsdf_val = evt.its.evalBSDF(wo);
                Float bsdf_pdf = evt.its.pdfBSDF(wo);
                Float mis_weight = pdf_nee < 0 ? 1.0 : square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
                ret[depth+num_interactions] += throughput * value * bsdf_val * mis_weight;
            }

            Float bsdf_pdf, bsdf_eta;
            Spectrum bsdf_weight = evt.its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
            if ( bsdf_weight.isZero(Epsilon) ) break;
            wo = evt.its.toWorld(wo);
            ray = Ray(evt.its.p, wo);

            if (evt.its.isMediumTransition())
                ptr_med = evt.its.getTargetMedium(wo);

            // Null-BSDF surfaces: just continue the ray; bsdf_weight is not applied and no emitter MIS is done.
            if (evt.its.ptr_bsdf->isNull())
                scene.rayIntersect(ray, true, evt.its);
            else {
                int num_interactions = 0;   // NOTE(review): shadows the outer num_interactions.
                Spectrum attenuated_radiance = scene.rayIntersectAndLookForEmitter(ray, true, sampler, ptr_med, max_depth-depth, evt.its, pdf_nee, num_interactions);
                throughput *= bsdf_weight;
                if ( !attenuated_radiance.isZero(Epsilon) ) {
                    Float mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                    ret[depth+num_interactions] += throughput * attenuated_radiance * mis_weight;
                }
            }
        }
        // Stop once the continuation ray found no valid intersection.
        if ( !evt.its.isValid() ) break;
        depth++;
        inside_med = ptr_med != nullptr && ptr_med->sampleDistance(ray, evt.its.t, sampler->next2D(), sampler, ray.org, throughput);
        if (throughput.isZero()) break;
    }
}

// Traces `num_paths` random paths and records every path vertex as a
// RadImpNode{position, throughput, depth}, appended to `nodes`.
//  * importance == false: paths start at the camera, at a random image position (inside
//    the crop rectangle when it is valid), with unit throughput, and BSDFs are sampled in
//    ERadiance mode.
//  * importance == true:  paths start at a position sampled on an emitter (recorded at
//    depth 0), and BSDFs are sampled in EImportanceWithCorrection mode.
// A path stops after max_bounces, on a zero BSDF/phase sample, on an inconsistent
// geometric/shading normal configuration, or when a continuation ray escapes the scene.
// Each OpenMP thread uses its own sampler and node buffer; the buffers are concatenated
// at the end, so the content order of `nodes` depends on thread scheduling.
void buildPhotonMap(const Scene& scene, int num_paths, int max_bounces, std::vector<RadImpNode> &nodes, bool importance) {
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(17, i));

    std::vector< std::vector<RadImpNode> > nodes_per_thread(nworker);
    for (int i = 0; i < nworker; i++) {
        nodes_per_thread[i].reserve(num_paths/nworker * max_bounces);
    }

    const Camera &camera = scene.camera;
    const CropRectangle &rect = camera.rect;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (size_t omp_i = 0; omp_i < static_cast<size_t>(num_paths); omp_i++) {
        const int tid = omp_get_thread_num();
        RndSampler& sampler = samplers[tid];
        Ray ray;
        Spectrum throughput;
        Intersection its;
        const Medium* ptr_med = nullptr;
        int depth = 0;
        bool onSurface = false;
        if ( !importance ) {
            // Trace ray from camera
            Float x = rect.isValid() ? rect.offset_x + sampler.next1D() * rect.crop_width : sampler.next1D() * camera.width;
            Float y = rect.isValid() ? rect.offset_y + sampler.next1D() * rect.crop_height : sampler.next1D() * camera.height;
            ray = camera.samplePrimaryRay(x, y);
            throughput = Spectrum(1.0);
            ptr_med = camera.getMedID() == -1 ? nullptr : scene.medium_list[camera.getMedID()];
        } else {
            // Trace ray from emitter
            throughput = scene.sampleEmitterPosition(sampler.next2D(), its);
            ptr_med = its.ptr_med_ext;
            nodes_per_thread[tid].push_back( RadImpNode{its.p, throughput, depth} );
            ray.org = its.p;
            throughput *= its.ptr_emitter->sampleDirection(sampler.next2D(), ray.dir);
            ray.dir = its.geoFrame.toWorld(ray.dir);
            depth++;
            onSurface = true;
        }

        if ( scene.rayIntersect(ray, onSurface, its) ) {
            while (depth <= max_bounces) {
                bool inside_med = ptr_med != nullptr &&
                                  ptr_med->sampleDistance(Ray(ray), its.t, sampler.next2D(), &sampler, ray.org, throughput);
                if ( inside_med ) {
                    if ( throughput.isZero() ) break;
                    const PhaseFunction* ptr_phase = scene.phase_list[ptr_med->phase_id];
                    nodes_per_thread[tid].push_back( RadImpNode{ray.org, throughput, depth} );
                    Float phase_val = ptr_phase->sample(-ray.dir, sampler.next2D(), ray.dir);
                    if ( phase_val == 0.0 ) break;
                    throughput *= phase_val;
                    // Stop if the scattered ray escapes. Otherwise the next iteration would
                    // read its.t / its.p and sample the BSDF of a failed intersection. This
                    // matches the handling in weightedImportance() and radiance().
                    if ( !scene.rayIntersect(ray, false, its) ) break;
                } else {
                    nodes_per_thread[tid].push_back( RadImpNode{its.p, throughput, depth} );
                    Float bsdf_pdf, bsdf_eta;
                    Vector wo_local, wo;
                    Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local, bsdf_pdf, bsdf_eta,
                                                          importance ? EBSDFMode::EImportanceWithCorrection : EBSDFMode::ERadiance);
                    if ( bsdf_weight.isZero() ) break;
                    throughput = throughput * bsdf_weight;

                    wo = its.toWorld(wo_local);
                    Vector wi = -ray.dir;
                    Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                    // Terminate on inconsistent geometric/shading normal configurations.
                    if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) break;

                    if ( its.isMediumTransition() )
                        ptr_med = its.getTargetMedium(woDotGeoN);

                    ray = Ray(its.p, wo);
                    if ( !scene.rayIntersect(ray, true, its) ) break;
                }
                depth++;
            }
        }
    }

    // Concatenate the per-thread buffers.
    size_t sz_node = 0;
    for (int i = 0; i < nworker; i++)
        sz_node += nodes_per_thread[i].size();
    nodes.reserve(sz_node);
    for (int i = 0; i < nworker; i++)
        nodes.insert(nodes.end(), nodes_per_thread[i].begin(), nodes_per_thread[i].end());
}

// Finds the NUM_NEAREST_NEIGHBORS points of `indices` nearest to `query_point`
// (3 coordinates).
//   matched_indices:  receives the point indices; must hold NUM_NEAREST_NEIGHBORS entries.
//   matched_dist_sqr: receives the squared distance of the last (farthest) match, or 0
//                     when nothing was found.
// Returns the number of matches. Fewer than NUM_NEAREST_NEIGHBORS trips the assert in
// debug builds.
int queryPhotonMap(const KDtree<Float> &indices, const Float* query_point, size_t* matched_indices, Float& matched_dist_sqr) {
    Float dist_sqr[NUM_NEAREST_NEIGHBORS];
    const int num_matched = static_cast<int>(indices.knnSearch(query_point, NUM_NEAREST_NEIGHBORS, matched_indices, dist_sqr));
    assert(num_matched == NUM_NEAREST_NEIGHBORS);
    // With asserts compiled out, an empty tree would otherwise read dist_sqr[-1].
    matched_dist_sqr = num_matched > 0 ? dist_sqr[num_matched - 1] : Float(0);
    return num_matched;
}

/***************************************************************************************************************************************************************/
void IntegratorAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    if ( !options.quiet )
        std::cout << "[INFO] Rendering using [ " << getName() << " ]" << std::endl;

    using namespace std::chrono;
    high_resolution_clock::time_point _init = high_resolution_clock::now();

    if ( options.num_samples > 0 ) {
        high_resolution_clock::time_point _start = high_resolution_clock::now();
        if ( !options.quiet )
            std::cout << "[INFO] Computing INTERIOR contribution, #samples = " << options.num_samples << std::endl;
        renderInterior(scene, options, rendered_image);
        if ( !options.quiet )
            std::cout << "[INFO] INTERIOR done in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
    }

    if ( options.num_samples_primary_edge > 0) {
        high_resolution_clock::time_point _start = high_resolution_clock::now();
        if ( !options.quiet )
            std::cout << "[INFO] Computing PRIMARY BOUNDARY contribution, #samples = " << options.num_samples_primary_edge << std::endl;
        renderEdgesPrimary(scene, options, rendered_image);
        if ( !options.quiet )
            std::cout << "[INFO] PRIMARY BOUNDARY done in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
    }

    if ( options.num_samples_secondary_edge_direct > 0 ) {
        high_resolution_clock::time_point _start = high_resolution_clock::now();
        if ( !options.quiet )
            std::cout << "[INFO] Computing DIRECT BOUNDARY contribution, #samples = " << options.num_samples_secondary_edge_direct << std::endl;
        renderEdgesDirect(scene, options, rendered_image);
        if ( !options.quiet )
            std::cout << "[INFO] DIRECT BOUNDARY done in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
        
        _start = high_resolution_clock::now();
        if ( !options.quiet )
            std::cout << "[INFO] render DIRECT BOUNDARY POINT LIGHT" << std::endl;
        renderEdgesPointLight(scene, options, rendered_image);
        if ( !options.quiet )
            std::cout << "[INFO] DIRECT BOUNDARY POINT LIGHT done in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
    }

    if ( options.num_samples_secondary_edge_indirect > 0 && options.max_bounces > 1) {
        high_resolution_clock::time_point _start = high_resolution_clock::now();
        if ( !options.quiet )
            std::cout << "[INFO] Computing INDIRECT BOUNDARY contribution, #samples = " << options.num_samples_secondary_edge_indirect << std::endl;
        renderEdges(scene, options, rendered_image);
        if ( !options.quiet )
            std::cout << "[INFO] INDIRECT BOUNDARY done in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
    }
    if ( !options.quiet )
        std::cout << "[INFO] All done in " << duration_cast<seconds>(high_resolution_clock::now() - _init).count() << " seconds." << std::endl;
}

void IntegratorAD_PathSpace::preprocess(const Scene &scene, int max_bounces, const GuidingOptions& opts, ptr<float> data) const {
    using namespace std::chrono;
    auto _start = high_resolution_clock::now();
    if (opts.type == 2)
    {
        if (!opts.quiet)
            printf("[INFO] PRIMARY guiding ( res=[%d] ) starts \n", opts.params[0]);
        preprocessPrimary(scene, max_bounces, opts, data);
    }
    else if (opts.type == 0)
    {
        if (!opts.quiet)
            printf("[INFO] DIRECT guiding ( res=[%d, %d, %d] ) starts \n", opts.params[0], opts.params[1], opts.params[2]);
        preprocessDirect(scene, max_bounces, opts, data);
    }
    else if (opts.type == 1)
    {
        if (max_bounces < 2)
        {
            std::cout << "[ERROR] max_bounces < 2, no indirect component. Guiding cancelled." << std::endl;
            assert(false);
        }
        if (!opts.quiet)
            printf("[INFO] INDIRECT guiding ( res=[%d, %d, %d] ) starts \n", opts.params[0], opts.params[1], opts.params[2]);
        preprocessIndirect(scene, max_bounces, opts, data);
    }
    else if (opts.type == 3)
    {
        if (!opts.quiet)
            printf("[INFO] DIRECT POINT LIGHT guiding ( res=[%d, %d] ) starts \n", opts.params[0], opts.params[1]);
        preprocessDirectPointLight(scene, max_bounces, opts, data);
    }
    else
    {
        std::cout << "[ERROR] Invalid guiding type!" << std::endl;
        assert(false);
    }

    if (!opts.quiet)
        std::cout << "\nDone in " << duration_cast<seconds>(high_resolution_clock::now() - _start).count() << " seconds." << std::endl;
}

/******   Primary Boundary Term    ******/
/*
 * Primary boundary term: Monte Carlo estimate of the image derivatives caused by edges
 * seen directly from the camera.
 *
 * Per sample:
 *   1. Sample a point pEdge on a primary edge (pdf edgePdf).
 *   2. Continue the camera ray through pEdge to its endpoint bEvt (surface or medium scatter).
 *   3. Evaluate the geometric factor baseValue and the normal n of the boundary segment.
 *   4. Splat baseValue * n.dot(du2/dθ_j) * boundary_throughput * L / edgePdf, scaled by
 *      the camera's per-pixel attenuation, into derivative channel j.
 *
 * Samples are processed in blocks of 128, each with its own RndSampler seeded by the
 * global sample index. Per-thread images are finally added into
 * rendered_image[((j + 1) * num_pixels + pixel) * 3 + c], i.e. the slots after the
 * first image, divided by the padded sample count.
 */
void IntegratorAD_PathSpace::renderEdgesPrimary(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {

    const Camera &camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_primary_edge)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
    // Every block runs all of its samples, so normalize by the padded count.
    num_samples = num_block*num_samples_per_block;
    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (long long index_block = 0; index_block < num_block; ++index_block) {
        const int tid = omp_get_thread_num();
        for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
            RndSampler sampler(options.seed, index_block*num_samples_per_block + omp_i);
            // sample a point on the edge for generating the boundary ray
            int shape_id;
            VectorAD pEdge;
            Float edgePdf;
            const Edge &edge = scene.ptr_psEdgeManager->samplePrimaryEdgePoint(sampler.next1D(), shape_id, pEdge, edgePdf);
            const Shape &shape = *scene.shape_list[shape_id];
            const VectorAD &v0 = shape.getVertexAD(edge.v0), &v1 = shape.getVertexAD(edge.v1);
            const Medium* ptr_medEdge = shape.med_ext_id < 0 ? nullptr : scene.medium_list[shape.med_ext_id];
            int depth_boundary = -1;
            Vector raydir = (pEdge.val - camera.cpos.val).normalized();
            Ray _edgeRay(pEdge.val, raydir);
            Spectrum boundary_throughput(1.0);
            EventRecord bEvt;
            bEvt.ptr_med = ptr_medEdge;
            if ( !traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay, options.max_bounces, depth_boundary, bEvt, boundary_throughput) ) continue;

            // Connect the endpoint to the camera and attenuate by the transmittance edge -> camera.
            Matrix2x4 pixel_uvs;
            Array4 attenuations(0.0);
            {
                Vector d0;
                Vector p = bEvt.onSurface ? bEvt.its.p : bEvt.scatter;
                camera.sampleDirect(p, pixel_uvs, attenuations, d0);
                if ( attenuations.isZero() ) continue;
                assert( (d0 + raydir).norm() < Epsilon );
                Float dist = (_edgeRay.org - camera.cpos.val).norm();
                Float transmittance = scene.evalTransmittance(_edgeRay.flipped(), true, ptr_medEdge, dist, &sampler, options.max_bounces, depth_boundary);
                if ( transmittance < Epsilon) continue;
                attenuations *= transmittance;
            }

            // evaluate the boundary segment
            Float baseValue = 0.0;
            Vector n;
            {
                Float dist = ((bEvt.onSurface ? bEvt.its.p : bEvt.scatter) - camera.cpos.val).norm();
                Float dist1 = (_edgeRay.org - camera.cpos.val).norm();
                Vector e = (v0.val - v1.val).normalized().cross(_edgeRay.dir);
                Vector e1 = shape.getVertex(edge.v2) -v0.val;
                Float sinphi = e.norm();
                if ( bEvt.onSurface ) {
                    Vector proj = e.cross(bEvt.its.geoFrame.n).normalized();
                    Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                    n = bEvt.its.geoFrame.n.cross(proj).normalized();
                    Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                    if ( sinphi > Epsilon && sinphi2 > Epsilon ) {
                        baseValue = deltaV * (dist/dist1)*(sinphi/sinphi2) * std::abs(bEvt.its.geoFrame.n.dot(-_edgeRay.dir));
                    }
                } else {
                    n = e.normalized();
                    Float deltaV = math::signum(e.dot(e1));
                    if ( sinphi > Epsilon )
                        baseValue = deltaV * (dist/dist1) * sinphi;
                }
            }
            if ( std::abs(baseValue) < Epsilon) continue;
            // Hack: assuming that the camera is fixed
            VectorAD x1 = VectorAD(camera.cpos.val);
            VectorAD u2;
            int max_interactions = options.max_bounces - depth_boundary;
            if ( getBoundaryEndpointFromBoundaryRay(scene, bEvt, x1, pEdge, u2) ) {
                if ( u2.der.isZero(Epsilon) ) continue;
                // depth_boundary is also passed to evalTransmittance above (presumably as an
                // in/out interaction counter), so the remaining budget can drop below zero.
                // A negative budget would give `L` no valid entry, and radiance(...) /
                // L[max_interactions] would index out of bounds.
                if ( max_interactions < 0 ) continue;
                std::vector<Spectrum> L(max_interactions+1, Spectrum::Zero());
                radiance(scene, &sampler, bEvt, max_interactions, &L[0]);
                // Prefix-sum so L[max_interactions] holds the radiance over all allowed path lengths.
                for (int i = 1; i <= max_interactions; i++)
                    L[i] += L[i-1];

                if ( !L[max_interactions].isZero(Epsilon) ) {
                    SpectrumAD contribBoundarySeg;
                    for ( int j = 0; j < nder; ++j )
                        contribBoundarySeg.grad(j) = baseValue * n.dot(u2.grad(j)) * boundary_throughput * L[max_interactions];

                    for ( int i = 0; i < 4; i++ ) {
                        if ( attenuations(i) < Epsilon ) continue;
                        const int pix_index = scene.camera.getPixelIndex(pixel_uvs.col(i));
                        for (int j = 0; j < nder; j++)
                            image_per_thread[tid][j*num_pixels + pix_index] += attenuations(i) * Spectrum(contribBoundarySeg.grad(j)) / edgePdf;
                    }
                }
            }
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !options.quiet ) std::cout << std::endl;

    // Reduce the per-thread derivative images into the output buffer.
    for ( int i = 0; i < nworker; ++i )
        for ( int j = 0; j < nder; ++j )
            for ( int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel ) {
                int offset1 = ((j + 1)*num_pixels + idx_pixel)*3,
                    offset2 = j*num_pixels + idx_pixel;
                rendered_image[offset1    ] += image_per_thread[i][offset2][0]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 1] += image_per_thread[i][offset2][1]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 2] += image_per_thread[i][offset2][2]/static_cast<Float>(num_samples);
            }
}

void IntegratorAD_PathSpace::preprocessPrimary(const Scene &scene, int max_bounces, const GuidingOptions& opts, ptr<float> data) const {
    assert(false);      // We disable primary boundary guiding

    if ( !opts.quiet )
        std::cout << "[INFO] Primary Guiding: #lightPath = " << opts.num_light_path << std::endl;
    std::vector<RadImpNode> imp_nodes;
    buildPhotonMap(scene, opts.num_light_path, max_bounces, imp_nodes, true);
    PointCloud<Float> imp_cloud;
    imp_cloud.pts.resize(imp_nodes.size());
    for (size_t i = 0; i < imp_nodes.size(); i++) {
        imp_cloud.pts[i].x = imp_nodes[i].p[0];
        imp_cloud.pts[i].y = imp_nodes[i].p[1];
        imp_cloud.pts[i].z = imp_nodes[i].p[2];
    }
    if ( !opts.quiet )
        std::cout << "[INFO] Primary Guiding: #imp_nodes = " << imp_nodes.size() << std::endl;
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();

    const std::vector<int> &params = opts.params;
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
    const auto &camera = scene.camera;
    constexpr int block_size = 100;
    if ( params[0] % block_size != 0) {
        printf("[INFO] Primary Guiding: Ensure that grid resoltion is a multiple of 100 (Hack..)\n");
        assert(false);
    }
    int num_block = params[0] / block_size;
    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( int omp_i = 0; omp_i < num_block; ++omp_i ) {
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];
        for (int i = 0; i < block_size; i++) {
            long long index = static_cast<long long>(omp_i) * block_size + i;
            assert(index < params[0]);
            Float res = 0.0;
            for (int t = 0; t < params[3]; ++t) {
                Float rnd = sampler.next1D();
                rnd = (index + rnd)/static_cast<Float>(params[0]);
                int shape_id;
                VectorAD pEdge;
                Float edgePdf;
                const Edge &edge = scene.ptr_psEdgeManager->samplePrimaryEdgePoint(rnd, shape_id, pEdge, edgePdf);
                const Shape &shape = *scene.shape_list[shape_id];
                const VectorAD &v0 = shape.getVertexAD(edge.v0), &v1 = shape.getVertexAD(edge.v1);

                const Medium* ptr_medEdge = shape.med_ext_id < 0 ? nullptr : scene.medium_list[shape.med_ext_id];
                int depth_boundary = -1;
                Vector raydir = (pEdge.val - camera.cpos.val).normalized();
                Ray _edgeRay(pEdge.val, raydir);
                Spectrum boundary_throughput(1.0);
                EventRecord bEvt;
                bEvt.ptr_med = ptr_medEdge;
                if ( !traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay, max_bounces, depth_boundary, bEvt, boundary_throughput) ) continue;
                Matrix2x4 pixel_uvs;
                Array4 attenuations(0.0);
                {
                    Vector d0;
                    Vector p = bEvt.onSurface ? bEvt.its.p : bEvt.scatter;
                    camera.sampleDirect(p, pixel_uvs, attenuations, d0);
                    if ( attenuations.isZero() ) continue;
                    assert((d0 + raydir).norm() < Epsilon);
                    Float dist = (_edgeRay.org - camera.cpos.val).norm();
                    Float transmittance = scene.evalTransmittance(_edgeRay.flipped(), true, ptr_medEdge, dist, &sampler, max_bounces, depth_boundary);
                    if ( transmittance < Epsilon) continue;
                    attenuations *= transmittance;
                }

                // evaluate the boundary segment
                Float baseValue = 0.0;
                Vector n;
                {
                    Float dist = ((bEvt.onSurface ? bEvt.its.p : bEvt.scatter) - camera.cpos.val).norm();
                    Float dist1 = (_edgeRay.org - camera.cpos.val).norm();
                    Vector e = (v0.val - v1.val).normalized().cross(_edgeRay.dir);
                    Vector e1 = shape.getVertex(edge.v2) -v0.val;
                    Float sinphi = e.norm();
                    if ( bEvt.onSurface ) {
                        Vector proj = e.cross(bEvt.its.geoFrame.n).normalized();
                        Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                        n = bEvt.its.geoFrame.n.cross(proj).normalized();
                        Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                        if ( sinphi > Epsilon && sinphi2 > Epsilon ) {
                            baseValue = deltaV * (dist/dist1)*(sinphi/sinphi2) * std::abs(bEvt.its.geoFrame.n.dot(-_edgeRay.dir));
                        }
                    } else {
                        n = e.normalized();
                        Float deltaV = math::signum(e.dot(e1));
                        if ( sinphi > Epsilon )
                            baseValue = deltaV * (dist/dist1) * sinphi;
                    }
                }
                if ( std::abs(baseValue) < Epsilon) continue;
                // Hack: assuming that the camera is fixed
                VectorAD x1 = VectorAD(camera.cpos.val);
                VectorAD u2;
                int max_interactions = max_bounces - depth_boundary;
                if ( getBoundaryEndpointFromBoundaryRay(scene, bEvt, x1, pEdge, u2) ) {
                    if ( u2.der.isZero(Epsilon) ) continue;
                    const Vector& p = bEvt.onSurface ? bEvt.its.p : bEvt.scatter;
                    Float pt_imp[3] = {p[0], p[1], p[2]};
                    Float matched_r2_imp;
                    size_t matched_indices[NUM_NEAREST_NEIGHBORS];
                    int num_nearby_imp = queryPhotonMap(imp_indices, pt_imp, matched_indices, matched_r2_imp);
                    Spectrum importance(0.0);
                    for (int m = 0; m < num_nearby_imp; m++) {
                        const RadImpNode& node = imp_nodes[matched_indices[m]];
                        if (node.depth <= max_interactions)
                            importance += node.val;
                    }
                    assert( !importance.isZero(Epsilon) );
                    SpectrumAD contribBoundarySeg;
                    for ( int j = 0; j < nder; ++j )
                        contribBoundarySeg.grad(j) = baseValue * n.dot(u2.grad(j)) * boundary_throughput * importance;
                    Float val = contribBoundarySeg.der.abs().maxCoeff() * attenuations.maxCoeff() / (edgePdf*matched_r2_imp);
                    if( std::isfinite(val) ) {
                        res += val;
                    }
                }
            }
            Float avg = res/static_cast<Float>(params[3]);
            data[index] = static_cast<float>(avg);
        }

        if ( !opts.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(static_cast<Float>(++finished_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !opts.quiet ) std::cout << std::endl;
}

/******   Direct Boundary Term    ******/
void IntegratorAD_PathSpace::renderEdgesDirect(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    if(scene.area_emitter_list.empty())
        return;
    const Camera &camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f));
    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_direct)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
    num_samples = num_block*num_samples_per_block;

    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( long long index_block = 0; index_block < num_block; ++index_block ) {
        std::pair<int, Spectrum> importance[BDPT_MAX_PATH_LENGTH];
        for ( int omp_i = 0; omp_i < num_samples_per_block; ++omp_i ) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, index_block*num_samples_per_block + omp_i);
            int shape_id;
            RayAD edgeRay;
            Float edgePdf;
            const Edge &rEdge = scene.sampleEdgeRayDirect(sampler.next3D(), shape_id, edgeRay, edgePdf);
            if ( shape_id == -1 ) continue;

            const Shape& shape = *scene.shape_list[shape_id];
            const Medium* ptr_medEdge = nullptr;
            if ( shape.med_ext_id >= 0)
                ptr_medEdge = scene.medium_list[shape.med_ext_id];
            // Trace ray to obtain boundary endpoint on emitter
            Spectrum attenuated_Le(0.0);
            Intersection itsEmitter;
            int depth_boundary = -1;
            Ray _edgeRay = edgeRay.toRay();
            if ( !getBoundaryEndpointOnEmitter(scene, &sampler, _edgeRay, ptr_medEdge, options.max_bounces,
                                               depth_boundary, itsEmitter, attenuated_Le) ) continue;
            // Trace the opposite ray to obtain the other endpoint
            EventRecord bEvt;
            bEvt.ptr_med = ptr_medEdge;
            if ( !traceRayForBoundaryEndpoint(scene, &sampler, Ray(_edgeRay.org, -_edgeRay.dir), options.max_bounces,
                                              depth_boundary, bEvt, attenuated_Le) ) continue;

            Float baseValue = 0.0f;
            Vector n;
            {
                Float dist = (itsEmitter.p - (bEvt.onSurface ? bEvt.its.p : bEvt.scatter)).norm();
                Float dist1 = (edgeRay.org.val - (bEvt.onSurface ? bEvt.its.p : bEvt.scatter)).norm();
                Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(_edgeRay.dir);
                Float sinphi = e.norm();
                Vector proj = e.cross(itsEmitter.geoFrame.n).normalized();
                Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                n = itsEmitter.geoFrame.n.cross(proj).normalized();
                Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                if ( sinphi > Epsilon && sinphi2 > Epsilon )
                    baseValue = deltaV*(dist1/dist)*(sinphi/sinphi2) * std::abs(itsEmitter.geoFrame.n.dot(-_edgeRay.dir));
            }
            if ( std::abs(baseValue) < Epsilon ) continue;

            int max_interactions = options.max_bounces - depth_boundary;
            const Shape &emitter = *itsEmitter.ptr_shape;
            const Vector3i &f1 = emitter.getIndices(itsEmitter.indices[1]);
            const VectorAD &v0 = emitter.getVertexAD(f1[0]), &v1 = emitter.getVertexAD(f1[1]), &v2 = emitter.getVertexAD(f1[2]);
            VectorAD x1;
            FloatAD J1;
            if ( bEvt.onSurface ) {
                VectorAD n1;
                scene.getPoint(bEvt.its, x1, n1, J1);
            }
            else
                if ( !bEvt.ptr_med->getPoint(bEvt.scatter, x1, J1) ) continue;

            VectorAD u2;
            if ( rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2) ) {
                if ( u2.der.isZero(Epsilon) ) continue;
                int num_indirect_path = weightedImportance(scene, &sampler, bEvt, max_interactions, nullptr, importance);
                if ( num_indirect_path > 0 ) {
                    FloatAD contribBoundarySeg;
                    for ( int j = 0; j < nder; ++j )
                        contribBoundarySeg.grad(j) = baseValue * n.dot(u2.grad(j));
                    for ( int j = 0; j < nder; ++j ) {
                        for (int k = 0; k < num_indirect_path; k++)
                            image_per_thread[tid][j*num_pixels + importance[k].first] += contribBoundarySeg.grad(j) * importance[k].second * attenuated_Le/edgePdf;
                    }
                }
            }
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !options.quiet ) std::cout << std::endl;

    for ( int i = 0; i < nworker; ++i )
        for ( int j = 0; j < nder; ++j )
            for ( int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel ) {
                int offset1 = ((j + 1)*num_pixels + idx_pixel)*3,
                    offset2 = j*num_pixels + idx_pixel;
                rendered_image[offset1    ] += image_per_thread[i][offset2][0]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 1] += image_per_thread[i][offset2][1]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 2] += image_per_thread[i][offset2][2]/static_cast<Float>(num_samples);
            }
}

/**
 * Build the 3D guiding grid for the direct (emitter) boundary term.
 *
 * rad_nodes are produced by buildPhotonMap from opts.num_cam_path camera paths and
 * indexed by a KD-tree. The sampling domain of Scene::sampleEdgeRayDirect, [0,1)^3, is
 * split into a params[0] x params[1] x params[2] grid. For each cell, params[3] jittered
 * edge rays are traced to their emitter endpoint and to the opposite boundary endpoint.
 * Each sample's largest derivative magnitude is scaled by the nearby rad_nodes and the
 * attenuated emission, and divided by edgePdf and by matched_r2_rad from
 * queryPhotonMap. The per-cell mean is written to data[(i*params[1] + j)*params[2] + k].
 *
 * Samples whose estimate is not finite (e.g. matched_r2_rad == 0) contribute 0 rather
 * than writing inf/NaN into the grid.
 *
 * @param scene        Scene being preprocessed.
 * @param max_bounces  Depth limit for photon-map construction and boundary tracing.
 * @param opts         Guiding options; params[0..2] = grid resolution, params[3] = samples per cell.
 * @param data         Output grid of params[0]*params[1]*params[2] floats.
 */
void IntegratorAD_PathSpace::preprocessDirect(const Scene &scene, int max_bounces, const GuidingOptions& opts, ptr<float> data) const {
    if ( !opts.quiet )
        std::cout << "[INFO] Direct Guiding: #camPath = " << opts.num_cam_path << std::endl;
    std::vector<RadImpNode> rad_nodes;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces, rad_nodes, false);
    if ( !opts.quiet )
        std::cout << "[INFO] Direct Guiding: #rad_nodes = " << rad_nodes.size() << std::endl;
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    for (size_t i = 0; i < rad_nodes.size(); i++) {
        rad_cloud.pts[i].x = rad_nodes[i].p[0];
        rad_cloud.pts[i].y = rad_nodes[i].p[1];
        rad_cloud.pts[i].z = rad_nodes[i].p[2];
    }
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();

    const std::vector<int> &params = opts.params;
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));

    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( int omp_i = 0; omp_i < params[0]*params[1]; ++omp_i ) {
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];

        const int i = omp_i/params[1], j = omp_i % params[1];
        for ( int k = 0; k < params[2]; ++k ) {
            Float res = 0.0f;
            for ( int t = 0; t < params[3]; ++t ) {
                // Jittered sample inside grid cell (i, j, k).
                Vector rnd = sampler.next3D();
                rnd[0] = (rnd[0] + i)/static_cast<Float>(params[0]);
                rnd[1] = (rnd[1] + j)/static_cast<Float>(params[1]);
                rnd[2] = (rnd[2] + k)/static_cast<Float>(params[2]);

                int shape_id;
                RayAD edgeRay;
                Float edgePdf, value = 0.0f;
                const Edge &rEdge = scene.sampleEdgeRayDirect(rnd, shape_id, edgeRay, edgePdf);
                if ( shape_id < 0 ) continue;

                const Shape& shape = *scene.shape_list[shape_id];
                const Medium* ptr_medEdge = nullptr;
                if ( shape.med_ext_id >= 0) ptr_medEdge = scene.medium_list[shape.med_ext_id];
                // Trace ray to obtain boundary endpoint on emitter
                Spectrum attenuated_Le(0.0);
                Intersection itsEmitter;
                int depth_boundary = -1;
                Ray _edgeRay = edgeRay.toRay();
                if ( !getBoundaryEndpointOnEmitter(scene, &sampler, _edgeRay, ptr_medEdge, max_bounces,
                                                   depth_boundary, itsEmitter, attenuated_Le) ) continue;
                // Trace the opposite ray to obtain the other endpoint
                EventRecord bEvt;
                bEvt.ptr_med = ptr_medEdge;
                if ( !traceRayForBoundaryEndpoint(scene, &sampler, Ray(_edgeRay.org, -_edgeRay.dir), max_bounces,
                                                  depth_boundary, bEvt, attenuated_Le) ) continue;

                // Geometric weight of the boundary segment at the emitter endpoint.
                Float baseValue = 0.0f;
                Vector n;
                {
                    Float dist = (itsEmitter.p - (bEvt.onSurface ? bEvt.its.p : bEvt.scatter)).norm();
                    Float dist1 = (edgeRay.org.val - (bEvt.onSurface ? bEvt.its.p : bEvt.scatter)).norm();
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(_edgeRay.dir);
                    Float sinphi = e.norm();
                    Vector proj = e.cross(itsEmitter.geoFrame.n).normalized();
                    Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                    n = itsEmitter.geoFrame.n.cross(proj).normalized();
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                    if ( sinphi > Epsilon && sinphi2 > Epsilon )
                        baseValue = deltaV*(dist1/dist)*(sinphi/sinphi2) * std::abs(itsEmitter.geoFrame.n.dot(-_edgeRay.dir));
                }
                if ( std::abs(baseValue) < Epsilon ) continue;

                int max_interactions = max_bounces - depth_boundary;
                const Shape &emitter = *itsEmitter.ptr_shape;
                const Vector3i &f1 = emitter.getIndices(itsEmitter.indices[1]);
                const VectorAD &v0 = emitter.getVertexAD(f1[0]), &v1 = emitter.getVertexAD(f1[1]), &v2 = emitter.getVertexAD(f1[2]);
                VectorAD x1;
                FloatAD J1;
                if ( bEvt.onSurface ) {
                    VectorAD n1;
                    scene.getPoint(bEvt.its, x1, n1, J1);
                }
                else
                    if ( !bEvt.ptr_med->getPoint(bEvt.scatter, x1, J1) ) continue;

                VectorAD u2;
                if ( rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2) ) {
                    if ( u2.der.isZero(Epsilon) ) continue;
                    // Sum nearby rad_nodes whose depth fits in the remaining interaction budget.
                    const Vector& p = bEvt.onSurface ? bEvt.its.p : bEvt.scatter;
                    Float pt_rad[3] = {p[0], p[1], p[2]};
                    Float matched_r2_rad;
                    size_t matched_indices[NUM_NEAREST_NEIGHBORS];
                    int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
                    Spectrum rad_estimate(0.0);
                    for (int m = 0; m < num_nearby_rad; m++) {
                        const RadImpNode& node = rad_nodes[matched_indices[m]];
                        if (node.depth <= max_interactions)
                            rad_estimate += node.val;
                    }
                    assert(!rad_estimate.isZero(Epsilon));
                    FloatAD contribBoundarySeg;
                    for ( int d = 0; d < nder; ++d )
                        contribBoundarySeg.grad(d) = baseValue * n.dot(u2.grad(d));
                    value = contribBoundarySeg.der.abs().maxCoeff() * (rad_estimate * attenuated_Le).maxCoeff()/edgePdf;
                    value /= matched_r2_rad;
                    // A zero matched radius (or other degenerate input) yields inf/NaN;
                    // drop the sample instead of poisoning the guiding grid. A bare assert
                    // would be compiled out in release builds.
                    if ( !std::isfinite(value) ) value = 0.0f;
                }
                res += value;
            }
            Float avg = res/static_cast<Float>(params[3]);
            data[static_cast<long long>(omp_i)*params[2] + k] = static_cast<float>(avg);
        }

        if ( !opts.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(static_cast<Float>(++finished_block)/(params[0]*params[1]));
            omp_unset_lock(&messageLock);
        }
    }
    if ( !opts.quiet ) std::cout << std::endl;
}

/******   Indirect Boundary Term    ******/
/**
 * Indirect boundary term: accumulate the derivative contribution of interior
 * (secondary) edges into the derivative slots of `rendered_image`.
 *
 * options.num_samples_secondary_edge_indirect samples per pixel are drawn, rounded up to
 * whole blocks of 128. Every sample uses its own RndSampler seeded with
 * (options.seed, global sample index). Each sample:
 *   1. draws an edge ray with Scene::sampleEdgeRay;
 *   2. traces it forward to bEvt0 (the endpoint on the discontinuity) and backward to
 *      bEvt1, accumulating boundary_throughput and depth_boundary;
 *   3. computes the geometric weight baseValue and the boundary normal n at bEvt0;
 *   4. obtains the differentiable bEvt0 position u2 via getBoundaryEndpointFromBoundaryRay;
 *   5. combines the per-depth values from radiance() at bEvt0 with the pixel weights
 *      returned by weightedImportance() at bEvt1, and splats into per-thread buffers.
 *
 * The per-thread buffers are reduced into image slots 1..nder (3 floats per pixel),
 * divided by the rounded sample count. The first slot is not touched here.
 */
void IntegratorAD_PathSpace::renderEdges(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    const Camera &camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_indirect)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
    num_samples = num_block*num_samples_per_block;
    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (long long index_block = 0; index_block < num_block; ++index_block) {
        std::pair<int, Spectrum> pathThroughput[BDPT_MAX_PATH_LENGTH];
        for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, index_block*num_samples_per_block + omp_i);
            // indirect contribution of edge term
            int shape_id;
            RayAD edgeRay;
            Float edgePdf;
            const Edge &rEdge = scene.sampleEdgeRay(sampler.next3D(), shape_id, edgeRay, edgePdf);
            if ( shape_id == -1 ) continue;

            const Shape& shape = *scene.shape_list[shape_id];
            const Medium* ptr_med = shape.med_ext_id < 0 ? nullptr : scene.medium_list[shape.med_ext_id];
            // construct the boundary segment
            EventRecord bEvt0, bEvt1;           // bEvt0 lies on the discontinuity curve/surface
            bEvt0.ptr_med = bEvt1.ptr_med = ptr_med;
            Spectrum boundary_throughput(Spectrum::Ones());
            int depth_boundary = -1;
            Ray _edgeRay = edgeRay.toRay();
            // The forward trace gets one bounce less than the backward trace; presumably
            // this leaves room for at least one interaction past bEvt0 (TODO confirm).
            bool valid_segment = traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay, options.max_bounces-1, depth_boundary, bEvt0, boundary_throughput) &&
                                 traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay.flipped(), options.max_bounces, depth_boundary, bEvt1, boundary_throughput);

            // Skip whenever no interaction budget is left. `>=` (not `==`) keeps
            // max_interactions positive below, where it sizes a std::vector and is
            // otherwise only checked by an assert.
            if ( !valid_segment || depth_boundary >= options.max_bounces ) continue;

            // evaluate the boundary segment
            Float baseValue = 0.0;
            Vector n;
            {
                Float dist = ((bEvt0.onSurface ? bEvt0.its.p : bEvt0.scatter) - (bEvt1.onSurface ? bEvt1.its.p : bEvt1.scatter)).norm();
                Float dist1 = (edgeRay.org.val - (bEvt1.onSurface ? bEvt1.its.p : bEvt1.scatter)).norm();
                Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(_edgeRay.dir);
                Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                Float sinphi = e.norm();
                if ( bEvt0.onSurface ) {
                    Vector proj = e.cross(bEvt0.its.geoFrame.n).normalized();
                    Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                    n = bEvt0.its.geoFrame.n.cross(proj).normalized();
                    Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                    if ( sinphi > Epsilon && sinphi2 > Epsilon )
                        baseValue = deltaV*(dist1/dist)*(sinphi/sinphi2) * std::abs(bEvt0.its.geoFrame.n.dot(-_edgeRay.dir));
                } else {
                    n = e.normalized();
                    Float deltaV = math::signum(e.dot(e1));
                    if ( sinphi > Epsilon )
                        baseValue = deltaV * (dist1/dist) * sinphi;
                }
            }
            if ( std::abs(baseValue) < Epsilon) continue;

            int max_interactions = options.max_bounces - depth_boundary;
            assert(max_interactions > 0);

            VectorAD x1;
            FloatAD J1;
            if ( bEvt1.onSurface ) {
                VectorAD n1;
                scene.getPoint(bEvt1.its, x1, n1, J1);
            } else {
                if ( !bEvt1.ptr_med->getPoint(bEvt1.scatter, x1, J1) ) continue;
            }

            VectorAD u2;
            if ( getBoundaryEndpointFromBoundaryRay(scene, bEvt0, x1, edgeRay.org, u2) ) {
                if ( u2.der.isZero(Epsilon) ) continue;
                // L[0] is zeroed and L is prefix-summed, so L[i] is the sum of the
                // per-depth terms 1..i written by radiance().
                std::vector<Spectrum> L(max_interactions+1, Spectrum::Zero());
                radiance(scene, &sampler, bEvt0, max_interactions, &L[0]);
                L[0] = Spectrum(0.0f);
                for (int i = 1; i <= max_interactions; i++)
                    L[i] += L[i-1];
                int num_path = weightedImportance(scene, &sampler, bEvt1, max_interactions, &L[0], pathThroughput);

                if ( num_path > 0 ) {
                    SpectrumAD contribBoundarySeg;
                    for ( int j = 0; j < nder; ++j )
                        contribBoundarySeg.grad(j) = baseValue * n.dot(u2.grad(j)) * boundary_throughput;

                    for ( int j = 0; j < nder; ++j ) {
                        Spectrum contrib_j = contribBoundarySeg.grad(j);
                        for (int k = 0; k < num_path; k++)
                            image_per_thread[tid][j*num_pixels + pathThroughput[k].first] += pathThroughput[k].second * contrib_j / edgePdf;
                    }
                }
            }
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !options.quiet )
        std::cout << std::endl;

    // Reduce per-thread buffers: derivative j of a pixel lands in image slot j + 1.
    for ( int i = 0; i < nworker; ++i )
        for ( int j = 0; j < nder; ++j )
            for ( int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel ) {
                int offset1 = ((j + 1)*num_pixels + idx_pixel)*3,
                    offset2 = j*num_pixels + idx_pixel;
                rendered_image[offset1    ] += image_per_thread[i][offset2][0]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 1] += image_per_thread[i][offset2][1]/static_cast<Float>(num_samples);
                rendered_image[offset1 + 2] += image_per_thread[i][offset2][2]/static_cast<Float>(num_samples);
            }
}

void IntegratorAD_PathSpace::preprocessIndirect(const Scene &scene, int max_bounces, const GuidingOptions& opts, ptr<float> data) const {
    if ( !opts.quiet )
        std::cout << "[INFO] Indirect Guiding: #camPath = " << opts.num_cam_path << ", #lightPath = " << opts.num_light_path << std::endl;
    std::vector<RadImpNode> rad_nodes, imp_nodes;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces, rad_nodes, false);
    buildPhotonMap(scene, opts.num_light_path, max_bounces-1, imp_nodes, true);
    if ( !opts.quiet )
        std::cout << "[INFO] Indirect Guiding: #rad_nodes = " << rad_nodes.size() << ", #imp_nodes = " << imp_nodes.size() << std::endl;

    // Build up the point KDtree for query
    PointCloud<Float> rad_cloud, imp_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    for (size_t i = 0; i < rad_nodes.size(); i++) {
        rad_cloud.pts[i].x = rad_nodes[i].p[0];
        rad_cloud.pts[i].y = rad_nodes[i].p[1];
        rad_cloud.pts[i].z = rad_nodes[i].p[2];
    }
    imp_cloud.pts.resize(imp_nodes.size());
    for (size_t i = 0; i < imp_nodes.size(); i++) {
        imp_cloud.pts[i].x = imp_nodes[i].p[0];
        imp_cloud.pts[i].y = imp_nodes[i].p[1];
        imp_cloud.pts[i].z = imp_nodes[i].p[2];
    }
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    KDtree<Float> imp_indices(3, imp_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    imp_indices.buildIndex();
    rad_indices.buildIndex();

    // Compute the importance grid
    const int nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
    const std::vector<int> &params = opts.params;
    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for ( int omp_i = 0; omp_i < params[0]*params[1]; ++omp_i ) {
        size_t matched_indices[NUM_NEAREST_NEIGHBORS];
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];

        const int i = omp_i/params[1], j = omp_i % params[1];
        for ( int k = 0; k < params[2]; ++k ) {
            Float res = 0.0f;
            for ( int t = 0; t < params[3]; ++t ) {
                Vector rnd = sampler.next3D();
                rnd[0] = (rnd[0] + i)/static_cast<Float>(params[0]);
                rnd[1] = (rnd[1] + j)/static_cast<Float>(params[1]);
                rnd[2] = (rnd[2] + k)/static_cast<Float>(params[2]);
                int shape_id;
                RayAD edgeRay;
                Float edgePdf;
                const Edge &rEdge = scene.sampleEdgeRay(rnd, shape_id, edgeRay, edgePdf);
                if ( shape_id < 0 ) continue;

                const Shape& shape = *scene.shape_list[shape_id];
                const Medium* ptr_med = shape.med_ext_id < 0 ? nullptr : scene.medium_list[shape.med_ext_id];
                // construct the boundary segment
                EventRecord bEvt0, bEvt1;           // bEvt0 lies on the discontinuity curve/surface
                bEvt0.ptr_med = bEvt1.ptr_med = ptr_med;
                Spectrum boundary_throughput(Spectrum::Ones());
                int depth_boundary = -1;
                Ray _edgeRay = edgeRay.toRay();
                bool valid_segment = traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay, max_bounces-1, depth_boundary, bEvt0, boundary_throughput) &&
                                     traceRayForBoundaryEndpoint(scene, &sampler, _edgeRay.flipped(), max_bounces, depth_boundary, bEvt1, boundary_throughput);
                if ( !valid_segment || max_bounces == depth_boundary) continue;

                // evaluate the boundary segment
                Float baseValue = 0.0;
                Vector n;
                {
                    Float dist = ((bEvt0.onSurface ? bEvt0.its.p : bEvt0.scatter) - (bEvt1.onSurface ? bEvt1.its.p : bEvt1.scatter)).norm();
                    Float dist1 = (edgeRay.org.val - (bEvt1.onSurface ? bEvt1.its.p : bEvt1.scatter)).norm();
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(_edgeRay.dir);
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    Float sinphi = e.norm();
                    if ( bEvt0.onSurface ) {
                        Vector proj = e.cross(bEvt0.its.geoFrame.n).normalized();
                        Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                        n = bEvt0.its.geoFrame.n.cross(proj).normalized();
                        Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                        if ( sinphi > Epsilon && sinphi2 > Epsilon )
                            baseValue = deltaV*(dist1/dist)*(sinphi/sinphi2) * std::abs(bEvt0.its.geoFrame.n.dot(-_edgeRay.dir));
                    } else {
                        n = e.normalized();
                        Float deltaV = math::signum(e.dot(e1));
                        if ( sinphi > Epsilon )
                            baseValue = deltaV * (dist1/dist) * sinphi;
                    }
                }
                if ( std::abs(baseValue) < Epsilon) continue;

                int max_interactions = max_bounces - depth_boundary;
                assert(max_interactions > 0);

                VectorAD x1;
                FloatAD J1;
                if ( bEvt1.onSurface ) {
                    VectorAD n1;
                    scene.getPoint(bEvt1.its, x1, n1, J1);
                } else {
                    if ( !bEvt1.ptr_med->getPoint(bEvt1.scatter, x1, J1) ) continue;
                }

                VectorAD u2;
                Float value1 = 0.0;
                if ( getBoundaryEndpointFromBoundaryRay(scene, bEvt0, x1, edgeRay.org, u2) ) {
                    if ( u2.der.isZero(Epsilon) ) continue;
                    SpectrumAD valAD;
                    for (int j = 0; j < nder; j++)
                        valAD.grad(j) = baseValue * n.dot(u2.grad(j)) * boundary_throughput / edgePdf;
                    value1 = valAD.der.abs().maxCoeff();

                    const Vector& p0 = bEvt0.onSurface ? bEvt0.its.p : bEvt0.scatter;
                    const Vector& p1 = bEvt1.onSurface ? bEvt1.its.p : bEvt1.scatter;
                    Float pt_rad[3] = {p1[0], p1[1], p1[2]};
                    Float pt_imp[3] = {p0[0], p0[1], p0[2]};
                    Float matched_r2_rad, matched_r2_imp;
                    int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
                    std::vector<Spectrum> radiance(max_interactions+1, Spectrum::Zero());
                    for (int m = 0; m < num_nearby_rad; m++) {
                        const RadImpNode& node = rad_nodes[matched_indices[m]];
                        if (node.depth <= max_interactions)
                            radiance[node.depth] += node.val;
                    }

                    int num_nearby_imp = queryPhotonMap(imp_indices, pt_imp, matched_indices, matched_r2_imp);
                    std::vector<Spectrum> importance(max_interactions, Spectrum::Zero());
                    for (int m = 0; m < num_nearby_imp; m++) {
                        const RadImpNode& node = imp_nodes[matched_indices[m]];
                        if (node.depth < max_interactions)
                            importance[node.depth] += node.val;
                    }


                    Spectrum value2 = Spectrum::Zero();
                    int impStart = 1;
                    for (int m = 0; m <= max_interactions; m++) {
                        for (int n = impStart; n < max_interactions-m; n++)
                            value2 += radiance[m] * importance[n];
                    }

                    value1 *= value2.maxCoeff() / (matched_r2_rad * matched_r2_imp);
                    assert(std::isfinite(value1));
                }
                res += value1;
            }

            Float avg = res/static_cast<Float>(params[3]);
            data[static_cast<long long>(omp_i)*params[2] + k] = static_cast<float>(avg);
        }

        if ( !opts.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(static_cast<Float>(++finished_block)/(params[0]*params[1]));
            omp_unset_lock(&messageLock);
        }
    }
    if ( !opts.quiet ) std::cout << std::endl;
}


/***** Point Light Direct Boundary Term ******/
void IntegratorAD_PathSpace::renderEdgesPointLight(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    if (scene.point_emitter_list.empty())
		return;
	const Camera &camera = scene.camera;
	int num_pixels = camera.getNumPixels();
	const int nworker = omp_get_num_procs();
	// const int nworker = 1;
	std::vector<std::vector<Spectrum> > image_per_thread(nworker);
	for (int i = 0; i < nworker; i++) image_per_thread[i].resize(nder*num_pixels, Spectrum(0.0f)); // init image per thread

	constexpr int num_samples_per_block = 128; // parallel block
	int num_samples_point = options.num_samples_secondary_edge_direct;
    assert(num_samples_point > 0);
	long long num_samples = static_cast<long long>(num_samples_point) * num_pixels; // total sample num
	const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block)); 
	num_samples = num_block * num_samples_per_block;

	#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
	for(long long index_block = 0; index_block < num_block; ++index_block) {
		std::pair<int, Spectrum> importance[BDPT_MAX_PATH_LENGTH];
		for(int omp_i = 0; omp_i < num_samples_per_block; ++omp_i)
		{
			const int tid = omp_get_thread_num();
			RndSampler sampler(options.seed, index_block * num_samples_per_block + omp_i);
			
			/* ==================== main function ==================== */
			EdgeRaySamplingRecordAD eRec;
			// sample an edge ray
			bool success = scene.ptr_psEdgeManager->sampleEdgeRayPointLight(sampler.next2D(), eRec);
            if(!success)
                continue;
			const Medium *ptr_medEdge = nullptr;
			if( eRec.shape->med_ext_id >= 0)
				ptr_medEdge = scene.medium_list[eRec.shape->med_ext_id];
			Spectrum throughput = Spectrum::Ones();
			int depth_boundary = 0;
			Float T = scene.evalTransmittance(eRec.ray.toRay(), true, ptr_medEdge, eRec.dist.val, 
								&sampler, options.max_bounces, depth_boundary);
			throughput *= T;
			// Trace opposite ray
			EventRecord bEvt;
			bEvt.ptr_med = ptr_medEdge;
			Ray ray(eRec.ray.org.val, -eRec.ray.dir.val); // opposite ray
			if (!traceRayForBoundaryEndpoint(scene, &sampler, ray, options.max_bounces, depth_boundary, bEvt, throughput))
				continue;
			
			Spectrum intensity = eRec.emitter->getIntensity();
						
			// change of variables
			Float xS_xB = eRec.dist.val;
			Vector xS = eRec.p.val;
			Vector xD = bEvt.x();
			Float xS_xD = (xS - xD).norm();
			const Vector &v_0 = eRec.shape->getVertex(eRec.edge->v0);
			const Vector &v_1 = eRec.shape->getVertex(eRec.edge->v1);
			const Vector &v_2 = eRec.shape->getVertex(eRec.edge->v2);
			Vector v2 = (v_2 - v_0).normalized();
			Vector v = (v_0 - v_1).normalized(); // edge vector
			Vector n = v.cross(eRec.ray.dir.val); // unnormalized face normal
			Float sinPhiB = n.norm();
			n.normalized();
			n *= -math::signum(n.dot(v2)); // make sure n pointing to the visible side
			Float J_B =  xS_xD / xS_xB * sinPhiB;

			Spectrum attenuated_Le = intensity / (xS_xD * xS_xD) * throughput;

			// calculate change rate
			FloatAD u; // change rate
			Float baseValue;
			if (bEvt.onSurface)
			{
				const Shape *shape_D = bEvt.its.ptr_shape;
				const Vector3i &f = shape_D->getIndices(bEvt.its.indices[1]);
				const VectorAD &v0 = shape_D->getVertexAD(f[0]);
				const VectorAD &v1 = shape_D->getVertexAD(f[1]);
				const VectorAD &v2 = shape_D->getVertexAD(f[2]);
				VectorAD x_D;
				rayIntersectTriangleAD(v0, v1, v2, RayAD(eRec.p, -eRec.ray.dir), x_D);
				
				Vector proj = n.cross(bEvt.its.geoFrame.n).normalized();
				Float sinPhiD = eRec.ray.dir.val.cross(proj).norm();
				J_B = J_B / sinPhiD;
				Vector n_D = bEvt.its.geoFrame.n.cross(proj).normalized();
				n_D *= math::signum(n.dot(n_D)); // make sure n_D pointing the visible side
				VectorAD nAD(n_D);
				u = nAD.dot(x_D); // normal velocity
				u.val = 0;
				Float cosPhiD = bEvt.its.geoFrame.n.dot(eRec.ray.dir.val);
				baseValue = J_B * cosPhiD;
			}
			else
			{
				VectorAD x_D;
				FloatAD J_D;
				if (!bEvt.ptr_med->getPoint(bEvt.scatter, x_D, J_D))
					continue;
				RayAD r(eRec.p, -eRec.ray.dir);
				Float t = xS_xD;
				VectorAD x_D1 = r(t);
				assert((x_D.val - x_D1.val).isZero(Epsilon));
				x_D.der += x_D1.der;
				VectorAD nAD(n);
				u = nAD.dot(x_D);
				u.val = 0;
				baseValue = J_B;
			}

			if (u.der.isZero(Epsilon))
				continue;
			
			int max_interactions = options.max_bounces - depth_boundary;
			int num_indirect_path = weightedImportance(scene, &sampler, bEvt, max_interactions, nullptr, importance);
			if( num_indirect_path > 0) {
				FloatAD contribBoundarySeg;
				for( int j = 0; j < nder; ++j)
					contribBoundarySeg.grad(j) = -baseValue * u.grad(j);
				for( int j = 0; j < nder; ++j) {
					for(int k = 0; k < num_indirect_path; k++)
						image_per_thread[tid][j * num_pixels + importance[k].first] += contribBoundarySeg.grad(j) * importance[k].second * attenuated_Le / eRec.pdf;
				}
			}

			/* ==================== main function ==================== */
		}
		if (!options.quiet)
		{
			omp_set_lock(&messageLock);
			progressIndicator(Float(index_block + 1) / num_block);
			omp_unset_lock(&messageLock);
		}
	}
	if (!options.quiet)
		std::cout << std::endl;

	for (int i = 0; i < nworker; ++i)
		for (int j = 0; j < nder; ++j)
			for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel)
			{
				int offset1 = ((j + 1) * num_pixels + idx_pixel) * 3,
					offset2 = j * num_pixels + idx_pixel;
				rendered_image[offset1] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
				rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
				rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
			}
}

void IntegratorAD_PathSpace::preprocessDirectPointLight(const Scene &scene, int max_bounces, const GuidingOptions &opts, ptr<float> data) const
{
    // build photon map and KDtree for accelerating query
    std::vector<RadImpNode> rad_nodes;
    buildPhotonMap(scene, opts.num_cam_path, max_bounces, rad_nodes, false);
    PointCloud<Float> rad_cloud;
    rad_cloud.pts.resize(rad_nodes.size());
    for (size_t i = 0; i < rad_nodes.size(); i++)
    {
        rad_cloud.pts[i].x = rad_nodes[i].p[0];
        rad_cloud.pts[i].y = rad_nodes[i].p[1];
        rad_cloud.pts[i].z = rad_nodes[i].p[2];
    }
    KDtree<Float> rad_indices(3, rad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
    rad_indices.buildIndex();

    /* fill in the grid */
    auto params = opts.params;
    params[2] = params[3]; //!
    // const int nworker = omp_get_num_procs();
    const int nworker = 1;
    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; ++i)
        samplers.push_back(RndSampler(13, i));

    int finished_block = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int omp_i = 0; omp_i < params[0]; ++omp_i)
    {
        int i = omp_i;
        const int tid = omp_get_thread_num();
        RndSampler &sampler = samplers[tid];
        for (int j = 0; j < params[1]; ++j)
        {
            Float res = 0.0f;
            /* sample in a grid block */
            for (int t = 0; t < params[2]; ++t)
            {
                /* sample a point in a grid block */
                Vector2 rnd = sampler.next2D();
                rnd[0] = (rnd[0] + i) / static_cast<Float>(params[0]); // egde point
                rnd[1] = (rnd[1] + j) / static_cast<Float>(params[1]); // emitter point

                /* convert the rnd to an edge ray */
                int shape_id;
                RayAD EdgeRay;
                Float edge_Pdf;
                EdgeRaySamplingRecordAD eRec;
                bool success = scene.ptr_psEdgeManager->sampleEdgeRayPointLight(rnd, eRec);
                if(!success)
                    continue;
                if (eRec.shape == nullptr)
                    continue;

                const Shape *shape = eRec.shape;
                const Medium *ptr_medEdge = nullptr;
                if (shape->med_ext_id >= 0)
                    ptr_medEdge = scene.medium_list[shape->med_ext_id];
                Spectrum throughput = Spectrum::Ones();
                int depth_boundary = 0;

                Float T = scene.evalTransmittance(eRec.ray.toRay(), true, ptr_medEdge, eRec.dist.val, 
								&sampler, max_bounces, depth_boundary);
                throughput *= T;

                EventRecord bEvt;
                bEvt.ptr_med = ptr_medEdge;
                if (!traceRayForBoundaryEndpoint(scene, &sampler, Ray(eRec.ray.org.val, -eRec.ray.dir.val), max_bounces, depth_boundary, bEvt, throughput))
                    continue;

                // change of variables
                Float xS_xB = eRec.dist.val;
                Vector xS = eRec.p.val;
                Vector xD = bEvt.x();
                Float xS_xD = (xS - xD).norm();
                const Vector &v_0 = eRec.shape->getVertex(eRec.edge->v0);
                const Vector &v_1 = eRec.shape->getVertex(eRec.edge->v1);
                const Vector &v_2 = eRec.shape->getVertex(eRec.edge->v2);
                Vector v2 = (v_2 - v_0).normalized();
                Vector v = (v_0 - v_1).normalized();  // edge vector
                Vector n = v.cross(eRec.ray.dir.val); // unnormalized face normal
                Float sinPhiB = n.norm();
                n.normalized();
                n *= -math::signum(n.dot(v2)); // make sure n pointing to the visible side
                Float J_B = xS_xD / xS_xB * sinPhiB;

                // calculate change rate
                FloatAD u; // change rate
                Float baseValue;
                if (bEvt.onSurface)
                {
                    const Shape *shape_D = bEvt.its.ptr_shape;
                    const Vector3i &f = shape_D->getIndices(bEvt.its.indices[1]);
                    const VectorAD &v0 = shape_D->getVertexAD(f[0]);
                    const VectorAD &v1 = shape_D->getVertexAD(f[1]);
                    const VectorAD &v2 = shape_D->getVertexAD(f[2]);
                    VectorAD x_D;
                    rayIntersectTriangleAD(v0, v1, v2, RayAD(eRec.p, -eRec.ray.dir), x_D);

                    Vector proj = n.cross(bEvt.its.geoFrame.n).normalized();
                    Float sinPhiD = eRec.ray.dir.val.cross(proj).norm();
                    J_B = J_B / sinPhiD;
                    Vector n_D = bEvt.its.geoFrame.n.cross(proj).normalized();
                    n_D *= math::signum(n.dot(n_D)); // make sure n_D pointing the visible side
                    VectorAD nAD(n_D);
                    u = nAD.dot(x_D); // normal velocity
                    u.val = 0;
                    Float cosPhiD = bEvt.its.geoFrame.n.dot(eRec.ray.dir.val);
                    baseValue = J_B * cosPhiD;
                }
                else
                {
                    VectorAD x_D;
                    FloatAD J_D;
                    if (!bEvt.ptr_med->getPoint(bEvt.scatter, x_D, J_D))
                        continue;
                    RayAD r(eRec.p, -eRec.ray.dir);
                    Float t = xS_xD;
                    VectorAD x_D1 = r(t);
                    assert((x_D.val - x_D1.val).isZero(Epsilon));
                    x_D.der += x_D1.der;
                    VectorAD nAD(n);
                    u = nAD.dot(x_D);
                    u.val = 0;
                    baseValue = J_B;
                }

                if (u.der.isZero(Epsilon))
                    continue;
                Spectrum intensity = eRec.emitter->getIntensity();
                Spectrum attenuated_Le = intensity / (xS_xD * xS_xD) * throughput;

                int max_interactions = max_bounces - depth_boundary;

                const Vector& p = bEvt.onSurface ? bEvt.its.p : bEvt.scatter;
                Float pt_rad[3] = {p[0], p[1], p[2]};
                Float matched_r2_rad;
                size_t matched_indices[NUM_NEAREST_NEIGHBORS];
                int num_nearby_rad = queryPhotonMap(rad_indices, pt_rad, matched_indices, matched_r2_rad);
                Spectrum radiance(0.0);
                for(int m = 0; m < num_nearby_rad; m++)
                {
                    const RadImpNode& node = rad_nodes[matched_indices[m]];
                    if(node.depth <= max_interactions)
                        radiance += node.val;
                }
                assert(!radiance.isZero(Epsilon));
                FloatAD contribBoundarySeg;
                for(int j = 0; j < nder; ++j)
                    contribBoundarySeg.grad(j) = -baseValue * u.grad(j);
                Float value = contribBoundarySeg.der.abs().maxCoeff() * (radiance * attenuated_Le).maxCoeff()/eRec.pdf;
                value /= matched_r2_rad;
                assert(std::isfinite(value));
                res += value;
            }
            
            Float avg = res / static_cast<Float>(params[2]);
            data[static_cast<long long>(omp_i) * params[1] + j] = static_cast<float>(avg);
        }
        if ( !opts.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block)/params[0]);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !opts.quiet ) std::cout << std::endl;

}