#include "binnedPathADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include "math_func.h"
#include <iomanip>
#include <chrono>

// File-scope worker count: the number of processors reported by OpenMP when this
// translation unit is initialised. renderInterior() uses this value directly; several
// other functions in this file declare a local `nworker` that shadows it.
static const int nworker = omp_get_num_procs();

// NOTE(review): this macro is not referenced in the visible part of the file —
// presumably it selects an alternative edge-sampling code path elsewhere; confirm
// before removing.
#define ALTERNATIVE_EDGE_SAMPLING

// Direct edges

// Evaluates one sampled boundary ray for the direct (emitter-connected) edge term.
//
// The edge ray is intersected with the scene in both directions: the hit along the ray
// is stored in eRec.its2 and must be an emitter; the hit along the flipped ray is stored
// in eRec.its1. Hits on the two faces adjacent to the sampled edge are rejected.
//
// Outputs (all reset at entry, so a rejected sample leaves them zero / -1):
//   eRec.value0        per-derivative Spectrum for the path its2 -> its1 -> sensor
//   eRec.sensor_vals,  four sensor weights and their pixel indices for that connection
//   eRec.pixel_indices
//   eRec.idx_pixel     pixel index of the single-sample sensor connection
//   path_length        |its2 - its1| + |its1 - camera|; written ONLY when the direct
//                      sensor connection succeeds (callers must check eRec.sensor_vals)
//   eRec.value1        per-derivative scalar used with longer sensor sub-paths
//   record             entries appended by weightedImportance() when max_bounces > 1
//
// Non-finite derivative entries are replaced by zero; a warning is printed unless `quiet`.
void Binned_PathTracerAD_PathSpace::evalEdgeDirect(const Scene &scene, int shape_id, const Edge &rEdge, const RayAD &edgeRay, RndSampler *sampler, int max_bounces,
    EdgeEvaluationRecord &eRec, Float &path_length, std::vector<std::tuple<int, Spectrum, Float>> &record, bool quiet) const
{
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), evalEdgeDirect() should not be used." << std::endl;
    assert(false);
#endif
    Intersection &its1 = eRec.its1, &its2 = eRec.its2;
    int &idx_pixel = eRec.idx_pixel;

    // Start from an empty record so every early-out below yields a zero contribution.
    eRec.value0.zero();
    eRec.value1.zero();
    idx_pixel = -1;

    eRec.sensor_vals = Array4(0.0);
    eRec.pixel_indices = Array4i(-1);

    const Shape &shape = *scene.shape_list[shape_id];
    Ray _edgeRay = edgeRay.toRay();
    const Vector &d1 = _edgeRay.dir;
    if (scene.rayIntersect(_edgeRay, true, its2) && scene.rayIntersect(_edgeRay.flipped(), true, its1)) {
        const Vector2i ind0(shape_id, rEdge.f0), ind1(shape_id, rEdge.f1);
        // Neither end point may lie on a face adjacent to the edge, and the far end must emit.
        if (its1.indices != ind0 && its1.indices != ind1 && its2.indices != ind0 && its2.indices != ind1 && its2.isEmitter()) {
            const Float gn1d1 = its1.geoFrame.n.dot(d1), sn1d1 = its1.shFrame.n.dot(d1),
                gn2d1 = its2.geoFrame.n.dot(-d1), sn2d1 = its2.shFrame.n.dot(-d1);
            assert(std::abs(its1.wi.z() - sn1d1) < Epsilon && std::abs(its2.wi.z() - sn2d1) < Epsilon);

            // Transmissive surfaces need geometric and shading normals to agree in sign;
            // opaque ones additionally need the ray on the front side of both normals.
            bool valid1 = (its1.ptr_bsdf->isTransmissive() && math::signum(gn1d1)*math::signum(sn1d1) > 0.5f) || (!its1.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
                valid2 = (its2.ptr_bsdf->isTransmissive() && math::signum(gn2d1)*math::signum(sn2d1) > 0.5f) || (!its2.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
            if (valid1 && valid2) {
                // Differentiable vertices of the emitter triangle hit at its2.
                const Shape &shape2 = *scene.shape_list[its2.indices[0]]; const Vector3i &f2 = shape2.getIndices(its2.indices[1]);
                const VectorAD &v0 = shape2.getVertexAD(f2[0]), &v1 = shape2.getVertexAD(f2[1]), &v2 = shape2.getVertexAD(f2[2]);
                Float baseValue = 0.0f;
                Float dist = (its2.p - its1.p).norm();
                Vector n;
                {
                    Float cos2 = std::abs(gn2d1);
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(d1);
                    Float sinphi = e.norm();
                    Vector proj = e.cross(its2.geoFrame.n).normalized();
                    Float sinphi2 = d1.cross(proj).norm();
                    n = its2.geoFrame.n.cross(proj).normalized();

                    // Sign of the contribution, from the edge's opposite vertex (rEdge.v2) and n.
                    Float deltaV;
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));

                    // Degenerate configurations (vanishing sines) keep baseValue at zero.
                    if (sinphi > Epsilon && sinphi2 > Epsilon)
                        baseValue = deltaV * (its1.t / dist)*(sinphi / sinphi2)*cos2;
                }

                if (std::abs(baseValue) > Epsilon) {
                    VectorAD u2;

                    // Direct: connect its1 straight to the sensor.
                    {
                        Vector2 pix_uv;
                        Vector d0;
                        Float sensor_val = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uv, d0);
                        if (sensor_val > Epsilon) {
                            RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
                            IntersectionAD its;
                            // The differentiable camera ray must land on the same point as its1.
                            if (scene.rayIntersectAD(cameraRay, false, its) && (its.p.val - its1.p).norm() < ShadowEpsilon) {
                                bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(its.p, edgeRay.org - its.p), u2);
                                if (valid) {
                                    path_length = dist + (its1.p - scene.camera.cpos.val).norm();

                                    Vector d0_local = its1.toLocal(d0);
                                    Spectrum value0 = its1.evalBSDF(d0_local, EBSDFMode::EImportanceWithCorrection) * baseValue * its2.Le(-d1);
                                    for (int j = 0; j < nder; ++j)
                                        eRec.value0.grad(j) = value0 * n.dot(u2.grad(j));

                                    // Second sensor query: the four-pixel variant used for splatting.
                                    Matrix2x4 pix_uvs;
                                    eRec.sensor_vals = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uvs, d0);
                                    for (int i = 0; i < 4; i++) {
                                        eRec.pixel_indices[i] = scene.camera.getPixelIndex(pix_uvs.col(i));
                                    }
                                    idx_pixel = scene.camera.getPixelIndex(pix_uv);
                                }
                            }
                        }
                    }

                    // Indirect: continue a sensor sub-path from its1.
                    if (max_bounces > 1) {
                        VectorAD x1, n1;
                        FloatAD J1;
                        scene.getPoint(its1, x1, n1, J1);
                        bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2);
                        if (valid) {
                            for (int j = 0; j < nder; ++j)
                                eRec.value1.grad(j) = baseValue * n.dot(u2.grad(j));

                            // Single weight entry carrying only the its1 -> its2 segment length;
                            // the Spectrum slot is a zero placeholder.
                            BinnedWeight weight(1);
                            weight[0].resize(1);
                            Spectrum tmp_spec(0.f);
                            weight[0][0] = std::tie(tmp_spec, dist);
                            weightedImportance(scene, sampler, its1, max_bounces - 1, true, weight, record);
                        }
                    }
                }
            }
        }

        // Replace non-finite derivatives by zero. Each derivative is checked exactly once.
        for (int j = 0; j < nder; ++j) {
            const bool value0_finite = std::isfinite(eRec.value0.grad(j)[0])
                                    && std::isfinite(eRec.value0.grad(j)[1])
                                    && std::isfinite(eRec.value0.grad(j)[2]);
            if (!value0_finite) {
                if (!quiet) {
                    omp_set_lock(&messageLock);
                    std::cerr << std::fixed << std::setprecision(2)
                        << "\n[WARN] Invalid gradient: [" << eRec.value0.grad(j).transpose() << "]" << std::endl;
                    omp_unset_lock(&messageLock);
                }
                eRec.value0.grad(j).setZero();
            }

            if (!std::isfinite(eRec.value1.grad(j))) {
                if (!quiet) {
                    omp_set_lock(&messageLock);
                    std::cerr << std::fixed << std::setprecision(2)
                        << "\n[WARN] Invalid gradient: [" << eRec.value1.grad(j) << "]" << std::endl;
                    omp_unset_lock(&messageLock);
                }
                eRec.value1.grad(j) = 0.0f;
            }
        }
    }
}

void Binned_PathTracerAD_PathSpace::renderEdgesDirect(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), renderEdgesDirect() should not be used." << std::endl;
    assert(false);
#endif
    const Camera &camera = scene.camera;
    const auto &pif = camera.pif;
    int num_pixels = camera.getNumPixels();
    const int num_bins = pif->num_bins;
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum>> image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_direct)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples) / num_samples_per_block));
    num_samples = num_block * num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (long long index_block = 0; index_block < num_block; ++index_block) {
        for (int omp_i = 0; omp_i < num_samples_per_block; ++omp_i) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, m_taskId[tid] = index_block * num_samples_per_block + omp_i);
            int shape_id;
            RayAD edgeRay;
            Float edgePdf;
            const Edge &rEdge = scene.sampleEdgeRayDirect(sampler.next3D(), shape_id, edgeRay, edgePdf);
            if (shape_id >= 0) {
                std::vector<std::tuple<int, Spectrum, Float>> importance;
                EdgeEvaluationRecord eRec;
                Float path_length;
                evalEdgeDirect(scene, shape_id, rEdge, edgeRay, &sampler, options.max_bounces, eRec, path_length, importance, options.quiet);

                // one-bounce
                if (!eRec.sensor_vals.isZero()) {
                    Vector2i bin_range = pif->getBinIndexRange(path_length);
                    for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                        int bin_offset = idx_bin * nder * num_pixels;
                        Float pif_kernel = pif->eval(path_length, idx_bin);
                        for (int i = 0; i < 4; i++) {
                            if (eRec.sensor_vals[i] > Epsilon) {
                                for (int j = 0; j < nder; ++j) {
                                    image_per_thread[tid][bin_offset + j * num_pixels + eRec.pixel_indices[i]] += pif_kernel * eRec.value0.grad(j) * eRec.sensor_vals[i] / edgePdf;
                                }
                            }
                        }
                    }
                }

                // multi-bounce
                if (!importance.empty()) {
                    Spectrum light_val = eRec.its2.Le(-edgeRay.dir.val);
                    for (const auto&[idx_pixel, value, total_path_length] : importance) {
                        Vector2i bin_range = pif->getBinIndexRange(total_path_length);
                        for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                            int bin_offset = idx_bin * nder * num_pixels;
                            Float pif_kernel = pif->eval(total_path_length, idx_bin);
                            for (int j = 0; j < nder; ++j) {
                                image_per_thread[tid][bin_offset + j * num_pixels + idx_pixel] += pif_kernel * eRec.value1.grad(j)*value*light_val / edgePdf;
                            }
                        }
                    }
                }
            }
        }

        if (!options.quiet) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block + 1) / num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if (!options.quiet) std::cout << std::endl;

    for (int i = 0; i < nworker; ++i)
        for (int idx_bin = 0; idx_bin < num_bins; ++idx_bin) {
            int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
            int bin_offset2 = idx_bin * nder * num_pixels;
            for (int j = 0; j < nder; ++j)
                for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel) {
                    int offset1 = (bin_offset1 + (j + 1)*num_pixels + idx_pixel) * 3,
                        offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                    rendered_image[offset1    ] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                }
        }
}

// Indirect edges

void Binned_PathTracerAD_PathSpace::renderEdges(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifdef USE_BOUNDARY_NEE
    if (options.max_bounces > 1) {
#else
    if (options.max_bounces >= 1) {
#endif
        const Camera &camera = scene.camera;
        const auto &pif = camera.pif;
        int num_pixels = camera.getNumPixels();
        const int num_bins = pif->num_bins;
        const int nworker = omp_get_num_procs();
        std::vector<std::vector<Spectrum> > image_per_thread(nworker);
        for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

        constexpr int num_samples_per_block = 128;
        long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_indirect)*num_pixels;
        const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples) / num_samples_per_block));
        num_samples = num_block * num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (long long index_block = 0; index_block < num_block; ++index_block) {
            for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
                const int tid = omp_get_thread_num();
                RndSampler sampler(options.seed, m_taskId[tid] = index_block * num_samples_per_block + omp_i);
                // indirect contribution of edge term
                int shape_id;
                RayAD edgeRay;
                Float edgePdf;
                const Edge &rEdge = scene.sampleEdgeRay(sampler.next3D(), shape_id, edgeRay, edgePdf);
                if (shape_id >= 0) {
                    EdgeEvaluationRecord eRec;
                    evalEdge(scene, shape_id, rEdge, edgeRay, &sampler, eRec);
                    if (eRec.its1.isValid() && eRec.its2.isValid()) {
                        bool zeroVelocity = eRec.value0.der.isZero(Epsilon) && eRec.value1.der.isZero(Epsilon);
                        if (!zeroVelocity) {
                            traceRayFromEdgeSegement(scene, eRec, edgePdf, options.max_bounces, &sampler, image_per_thread[tid]);
                        }
                    }
                }
            }

            if (!options.quiet) {
                omp_set_lock(&messageLock);
                progressIndicator(Float(index_block + 1) / num_block);
                omp_unset_lock(&messageLock);
            }
        }
        if (!options.quiet)
            std::cout << std::endl;

        for (int i = 0; i < nworker; ++i)
            for (int idx_bin = 0; idx_bin < num_bins; ++idx_bin) {
                int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
                int bin_offset2 = idx_bin * nder * num_pixels;
                for (int j = 0; j < nder; ++j)
                    for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel) {
                        int offset1 = (bin_offset1 + (j + 1)*num_pixels + idx_pixel) * 3,
                            offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                        rendered_image[offset1] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                        rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                        rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                    }
            }
    }
}

// Helper functions
void Binned_PathTracerAD_PathSpace::weightedImportance(const Scene& scene, RndSampler* sampler, const Intersection& its0, int max_depth, bool is_direct,
    const BinnedWeight& weight, std::vector<std::tuple<int, Spectrum, Float>>& ret) const {
    Intersection its = its0;
    Spectrum throughput(1.0f);
    Vector d0;
    Matrix2x4 pix_uvs;
    Ray ray_sensor;
    Float current_path_length = 0;
    const auto &camera = scene.camera;

    for (int d_sensor = 1; d_sensor <= max_depth; d_sensor++) {
        // sample a new direction
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, EBSDFMode::EImportanceWithCorrection);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo_local);
        Vector wi = its.toWorld(its.wi);
        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
            break;
        throughput *= bsdf_weight;
        ray_sensor = Ray(its.p, wo);
        scene.rayIntersect(ray_sensor, true, its);
        if (!its.isValid())
            break;

        // Update path length
        current_path_length += its.t;

        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(its, sampler, 0, pix_uvs, d0);
        if (!sensor_vals.isZero()) {
            Vector wi = -ray_sensor.dir;
            Vector wo = d0, wo_local = its.toLocal(wo);
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() > 0 && woDotGeoN * wo_local.z() > 0) {
                assert(ret.size() <= BDPT_MAX_PATH_LENGTH * 4);

                if (is_direct) {
                    assert(weight.size() == 1);

                    for (int i = 0; i < 4; i++) {
                        if (sensor_vals[i] > Epsilon) {
                            int idx_pixel = camera.getPixelIndex(pix_uvs.col(i));
                            Spectrum L = throughput * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection) * sensor_vals[i];
                            Float total_path_length = current_path_length + (its.p - camera.cpos.val).norm() + std::get<1>(weight[0][0]);
                            ret.push_back(std::tie(idx_pixel, L, total_path_length));
                        }
                    }
                }
                else {
                    for (int d_emitter = 0; d_emitter <= max_depth - d_sensor; d_emitter++) {
                        for (const auto & [L_, path_length] : weight[d_emitter]) {
                            for (int i = 0; i < 4; i++) {
                                if (sensor_vals[i] > Epsilon) {
                                    int idx_pixel = camera.getPixelIndex(pix_uvs.col(i));
                                    Float total_path_length = current_path_length + (its.p - camera.cpos.val).norm() + path_length;
                                    Spectrum L = L_ * throughput * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection) * sensor_vals[i];
                                    ret.push_back(std::tie(idx_pixel, L, total_path_length));
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

void Binned_PathTracerAD_PathSpace::radiance(const Scene& scene, RndSampler* sampler, const Intersection &its, int max_bounces,
        BinnedWeight &ret) const {
    Intersection _its(its);
    Float current_path_length = 0;

    if (_its.isEmitter()) {
        Spectrum L = _its.Le(_its.toWorld(_its.wi));
        Float path_length = current_path_length;
        ret[0].push_back(std::tie(L, path_length));
    }

    Spectrum throughput(1.0f);
    for (int d_emitter = 1; d_emitter <= max_bounces; d_emitter++) {
        // Direct illumination
        Float pdf_nee;
        Float dist;
        Vector wo;
        Spectrum value = scene.sampleEmitterDirect(_its, sampler->next2D(), sampler, wo, pdf_nee, &dist);
        if (!value.isZero(Epsilon)) {
            Spectrum bsdf_val = _its.evalBSDF(wo);
            Float bsdf_pdf = _its.pdfBSDF(wo);
            Float mis_weight = pdf_nee / (pdf_nee + bsdf_pdf);

            Spectrum L = throughput * value*bsdf_val*mis_weight;
            Float path_length = current_path_length + dist;
            ret[d_emitter].push_back(std::tie(L, path_length));
        }

        // Indirect illumination
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = _its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if (bsdf_weight.isZero(Epsilon)) break;
        wo = _its.toWorld(wo);

        Ray ray_emitter(_its.p, wo);
        if (!scene.rayIntersect(ray_emitter, true, _its)) break;
        throughput *= bsdf_weight;
        current_path_length += (_its.p - ray_emitter.org).norm();

        if (_its.isEmitter()) {
            Spectrum light_contrib = _its.Le(-ray_emitter.dir);
            if (!light_contrib.isZero(Epsilon)) {
                Float dist_sq = (_its.p - ray_emitter.org).squaredNorm();
                Float G = _its.geoFrame.n.dot(-ray_emitter.dir) / dist_sq;
                pdf_nee = scene.pdfEmitterSample(_its) / G;
                Float mis_weight = bsdf_pdf / (pdf_nee + bsdf_pdf);

                Spectrum L = throughput * light_contrib*mis_weight;
                ret[d_emitter].push_back(std::tie(L, current_path_length));
            }
        }
    }
}

void Binned_PathTracerAD_PathSpace::traceRayFromEdgeSegement(const Scene &scene, const EdgeEvaluationRecord& eRec, Float edgePdf, int max_depth, RndSampler *sampler, std::vector<Spectrum> &image) const {
    assert(max_depth > 0);
    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    const int num_bins = pif->num_bins;
    const int num_pixels = image.size() / num_bins / nder;

    /*** Trace ray towards emitter from its2 ***/
    BinnedWeight weight(max_depth);
    {
        radiance(scene, sampler, eRec.its2, max_depth - 1, weight);
#ifdef USE_BOUNDARY_NEE
        weight[0].clear();
#endif

        Float edge_length = (eRec.its2.p - eRec.its1.p).norm();
        for (int d_emitter = 0; d_emitter < max_depth; d_emitter++) {
            for (auto& w : weight[d_emitter]) {
                std::get<1>(w) += edge_length;
            }
        }
    }

    /*** Trace ray towards sensor from its1 ***/
    if (!eRec.sensor_vals.isZero()) {
        Float camera_its1_length = (scene.camera.cpos.val - eRec.its1.p).norm();

        // Direct connect to sensor
        for (int d_emitter = 0; d_emitter < max_depth; d_emitter++) {
            for (const auto&[L, path_length] : weight[d_emitter]) {
                Float total_path_length = path_length + camera_its1_length;

                Vector2i bin_range = pif->getBinIndexRange(total_path_length);
                for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                    int bin_offset = idx_bin * nder * num_pixels;
                    Float pif_kernel = pif->eval(total_path_length, idx_bin);
                    for (int i = 0; i < 4; ++i) {
                        if (eRec.sensor_vals[i] > Epsilon) {
                            for (int j = 0; j < nder; ++j) {
                                image[bin_offset + j * num_pixels + eRec.pixel_indices[i]] += pif_kernel * L * Spectrum(eRec.value0.grad(j)) * eRec.sensor_vals[i] / edgePdf;
                            }
                        }
                    }
                }
            }
        }
    }

    if (max_depth > 1) {
        std::vector<std::tuple<int, Spectrum, Float>> pathThroughput;
        weightedImportance(scene, sampler, eRec.its1, max_depth - 1, false, weight, pathThroughput);
        for (const auto &[idx_pixel, L, total_path_length] : pathThroughput) {
            Vector2i bin_range = pif->getBinIndexRange(total_path_length);
            for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                int bin_offset = idx_bin * nder * num_pixels;
                Float pif_kernel = pif->eval(total_path_length, idx_bin);
                for (int j = 0; j < nder; ++j) {
                    image[bin_offset + j * num_pixels + idx_pixel] += pif_kernel * L * eRec.value1.grad(j) / edgePdf;
                }
            }
        }
    }
}

std::vector<std::tuple<SpectrumAD, FloatAD>> Binned_PathTracerAD_PathSpace::LiAD(const Scene &scene, RndSampler* sampler, const RayAD &_ray, int pixel_x, int pixel_y, int max_depth) const {
    Ray ray = _ray.toRay();
	Intersection its;
	IntersectionAD itsAD, pre_itsAD;
	std::vector<std::tuple<SpectrumAD, FloatAD>> records;
	SpectrumAD throughputAD(Spectrum::Ones());
	Float eta = 1.0f;
    FloatAD path_length = 0;
    Float nee_pdf, bsdf_pdf = 1.f;

    pre_itsAD.p = _ray.org;

	scene.rayIntersect(ray, false, its);

	for(int depth = 0; depth <= max_depth; depth++) {
		bool isFirst = (depth == 0);
		
        if (!its.isValid())
            break;

        // getPoint
        FloatAD J;
        scene.getPoint(its, pre_itsAD.p, itsAD, J);
        VectorAD x = itsAD.p;
        VectorAD dir = x - pre_itsAD.p;
        FloatAD dist = dir.norm();
        dir /= dist;

        path_length += dist;

        // calculate throughput
        VectorAD dir_local = pre_itsAD.toLocal(dir);
        SpectrumAD f = isFirst ? scene.camera.evalFilterAD(pixel_x, pixel_y, itsAD) : pre_itsAD.evalBSDF(dir_local) / bsdf_pdf;
        FloatAD G = isFirst ? FloatAD(1.) : itsAD.geoFrame.n.dot(-dir) / dist.square();
        Float pdf = G.val;
        throughputAD *= f * G * J / pdf;
        bsdf_pdf *= pdf;

        if (itsAD.isEmitter()) {
            SpectrumAD ret;
            if (isFirst) ret = throughputAD * itsAD.Le(-dir);
            else {
                nee_pdf = scene.pdfEmitterSample(itsAD);
                auto mis_weight = bsdf_pdf / (nee_pdf + bsdf_pdf);
                // mis_weight = 0.5;
                ret = throughputAD * itsAD.Le(-dir) * mis_weight;
            }
            records.push_back(std::tie(ret, path_length));
        }

        if (depth >= max_depth)
            break;

        VectorAD woAD;
        Float G1;

        // Light sampling
        auto value = scene.sampleEmitterDirectAD(itsAD, sampler->next2D(), sampler, woAD, nee_pdf, &G1);
        if (!value.isZero(Epsilon)) {
            Intersection tmpIts;
            IntersectionAD tmpItsAD;
            scene.rayIntersect(Ray(itsAD.p.val, itsAD.toWorld(woAD).val), true, tmpIts);

            FloatAD J;
            scene.getPoint(tmpIts, tmpItsAD, J);
            FloatAD total_path_length = path_length + (tmpItsAD.p - itsAD.p).norm();

            auto bsdf_val = itsAD.evalBSDF(woAD);
            bsdf_pdf = itsAD.pdfBSDF(woAD.val) * G1;
            auto mis_weight = nee_pdf / (nee_pdf + bsdf_pdf);
            // mis_weight = 0.5;
            SpectrumAD ret = throughputAD * value * bsdf_val * mis_weight;
            records.push_back(std::tie(ret, total_path_length));
        }

        // BSDF sampling
        Vector wo;
        Float bsdf_eta;
        auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo);
        ray = Ray(its.p, wo);
        eta *= bsdf_eta;

        scene.rayIntersect(ray, true, its);

        pre_itsAD = itsAD;
	}
	return records;
}

std::vector<std::tuple<SpectrumAD, int>> Binned_PathTracerAD_PathSpace::pixelColorAD(const Scene &scene, const RenderOptions &options, RndSampler *sampler, int x, int y) const {
    int pixel_x = x;
	int pixel_y = y;
	std::vector<std::tuple<SpectrumAD, int>> ret;
	Ray rays[2];
	const int max_depth = options.max_bounces;
	const auto &camera = scene.camera;
    const auto &pif = camera.pif;
	camera.samplePrimaryRayFromFilter(pixel_x, pixel_y, sampler->next2D(), rays[0], rays[1]);
	VectorAD o = camera.cpos;
	uint64_t state = sampler->state;
	for (int i = 0; i < 2; i++) {
		sampler->state = state;
		Ray &ray = rays[i];
        auto samples = LiAD(scene, sampler, RayAD(o, ray.dir), pixel_x, pixel_y, max_depth);
        for (auto& sample : samples) {
            std::get<0>(sample) *= 0.5;
        }

        std::vector<std::tuple<SpectrumAD, int>> split_res;
        for (const auto&[contrib, path_length] : samples) {
            Vector2i bin_range = pif->getBinIndexRange(path_length.val);
            for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                FloatAD pif_kernel = pif->evalAD(path_length, idx_bin);
                split_res.push_back(std::make_tuple(contrib * pif_kernel, idx_bin));
            }
        }
        ret.insert(ret.end(), split_res.begin(), split_res.end());
	}
	return ret;
}

// Renders the interior (non-boundary) term: the primal image and its derivatives, per bin.
//
// Pixels are processed in blocks of 4 by an OpenMP dynamic schedule. For every pixel,
// options.num_samples calls to pixelColorAD() are accumulated per bin; a contribution is
// dropped (with a warning unless options.quiet) when its value is non-finite or negative,
// or when its largest absolute derivative is non-finite or reaches options.grad_threshold.
//
// Layout written: for bin b, channel c (0 = primal, 1..nder = derivatives), pixel p,
// RGB starts at ((b * (nder + 1) + c) * num_pixels + p) * 3. Each thread writes to
// distinct pixels, so no per-thread buffers are used here.
//
// Does nothing when options.num_samples <= 0 or when the bin count is invalid.
void Binned_PathTracerAD_PathSpace::renderInterior(const Scene& scene, const RenderOptions& options, ptr<float> rendered_image) const {
    if (!options.quiet)
        std::cout << std::scientific << std::setprecision(1) << "[INFO] grad_threshold = " << options.grad_threshold << std::endl;

    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    const bool cropped = camera.rect.isValid();
    const int num_pixels = cropped ? camera.rect.crop_width * camera.rect.crop_height
        : camera.width * camera.height;
    const int size_block = 4;
    const int num_bins = pif->num_bins;

    if (!options.quiet)
        std::cout << "Rendering using [ " << getName() << " ] and " << nworker << " workers ..." << std::endl;

    if (num_bins <= 0) {
        std::cout << "Invalid number of bins: " << num_bins << std::endl;
        // Continuing would size the per-pixel accumulator with a non-positive count and
        // then index into it, so stop here.
        return;
    }

    // Pixel sampling
    if (options.num_samples > 0) {
        // Integer ceil-division: number of blocks needed to cover all pixels.
        const int num_block = (num_pixels + size_block - 1) / size_block;
        int finished_block = 0;  // only touched while holding messageLock
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (int index_block = 0; index_block < num_block; index_block++) {
            int block_start = index_block * size_block;
            int block_end = std::min((index_block + 1)*size_block, num_pixels);

            for (int idx_pixel = block_start; idx_pixel < block_end; idx_pixel++) {
                // Map the linear pixel index to image coordinates (crop-aware).
                int ix = cropped ? camera.rect.offset_x + idx_pixel % camera.rect.crop_width
                    : idx_pixel % camera.width;
                int iy = cropped ? camera.rect.offset_y + idx_pixel / camera.rect.crop_width
                    : idx_pixel / camera.width;
                RndSampler sampler(options.seed, idx_pixel);

                std::vector<SpectrumAD> pixel_vals(num_bins, SpectrumAD(0));
                for (int idx_sample = 0; idx_sample < options.num_samples; idx_sample++) {
                    auto tmps = pixelColorAD(scene, options, &sampler, ix, iy);
                    for (const auto&[tmp, idx_bin] : tmps) {
                        bool val_valid = std::isfinite(tmp.val[0]) && std::isfinite(tmp.val[1]) && std::isfinite(tmp.val[2]) && tmp.val.minCoeff() >= 0.0f;
                        Float tmp_val = tmp.der.abs().maxCoeff();
                        bool der_valid = std::isfinite(tmp_val) && tmp_val < options.grad_threshold;

                        if (val_valid && der_valid) {
                            pixel_vals[idx_bin] += tmp;
                        } else if (!options.quiet) {
                            omp_set_lock(&messageLock);
                            if (!val_valid)
                                std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << tmp.val << "]" << std::endl;
                            if (!der_valid)
                                std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << tmp.der << "]" << std::endl;
                            omp_unset_lock(&messageLock);
                        }
                    }
                }
                for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
                    SpectrumAD pixel_val = pixel_vals[idx_bin] / options.num_samples;

                    int bin_offset = idx_bin * (nder + 1) * num_pixels * 3;

                    // Channel 0: primal value.
                    rendered_image[bin_offset + idx_pixel * 3] += static_cast<float>(pixel_val.val(0));
                    rendered_image[bin_offset + idx_pixel * 3 + 1] += static_cast<float>(pixel_val.val(1));
                    rendered_image[bin_offset + idx_pixel * 3 + 2] += static_cast<float>(pixel_val.val(2));
                    // Channels 1..nder: derivatives.
                    for (int ch = 1; ch <= nder; ++ch) {
                        int offset = (ch*num_pixels + idx_pixel) * 3;
                        rendered_image[bin_offset + offset] += static_cast<float>((pixel_val.grad(ch - 1))(0));
                        rendered_image[bin_offset + offset + 1] += static_cast<float>((pixel_val.grad(ch - 1))(1));
                        rendered_image[bin_offset + offset + 2] += static_cast<float>((pixel_val.grad(ch - 1))(2));
                    }
                }
            }

            if (!options.quiet) {
                omp_set_lock(&messageLock);
                progressIndicator(Float(++finished_block) / num_block);
                omp_unset_lock(&messageLock);
            }
        }

        // Terminate the progress line; nothing was printed in quiet mode.
        if (!options.quiet)
            std::cout << std::endl;
    }
}

bool Binned_PathTracerAD_PathSpace::getBoundaryEndpointFromBoundaryRay(const Scene& scene, const Intersection& its, 
    const VectorAD& x1, const VectorAD& pEdge, VectorAD& u2) const {
    RayAD bRay(x1, pEdge-x1);
    const Shape &shape = *scene.shape_list[its.indices[0]];
    const Vector3i &f1 = shape.getIndices(its.indices[1]);
    const VectorAD &v0 = shape.getVertexAD(f1[0]), &v1 = shape.getVertexAD(f1[1]), &v2 = shape.getVertexAD(f1[2]);
    return rayIntersectTriangleAD(v0, v1, v2, bRay, u2);
}

void Binned_PathTracerAD_PathSpace::renderPrimaryEdges(const Scene& scene, const RenderOptions& options, ptr<float> rendered_image) const {
    const Camera &camera = scene.camera;
    const auto &pif = camera.pif;
    const int num_pixels = camera.getNumPixels();
    const int num_bins = pif->num_bins;
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_primary_edge)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples)/num_samples_per_block));
    num_samples = num_block*num_samples_per_block;
	#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (long long index_block = 0; index_block < num_block; ++index_block) {
        for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, m_taskId[tid] = index_block*num_samples_per_block + omp_i);

            // sample a point on the edge for generating the boundary ray
            int shape_id;
            VectorAD pEdge;
            Float edgePdf;

            const Edge &edge = scene.ptr_psEdgeManager->samplePrimaryEdgePoint(sampler.next1D(), shape_id, pEdge, edgePdf);
            const Shape &shape = *scene.shape_list[shape_id];
		    const VectorAD &v0 = shape.getVertexAD(edge.v0), &v1 = shape.getVertexAD(edge.v1);
		    // pEdge = v0 + (v1 - v0)*rnd;

			int depth_boundary = 0;
			Vector raydir = (pEdge.val - camera.cpos.val).normalized();
			Ray _edgeRay(pEdge.val, raydir);
            Spectrum boundary_throughput(1.0);
			Intersection its;

            if (!scene.rayIntersect(_edgeRay, true, its)) continue;
            Float gnDotD = its.geoFrame.n.dot(-_edgeRay.dir);
            Float snDotD = its.shFrame.n.dot(-_edgeRay.dir);
            if (gnDotD < Epsilon || snDotD < Epsilon) continue;
            if (!scene.isVisible(camera.cpos.val, false, pEdge.val, true)) continue;

			Matrix2x4 pixel_uvs;
			Array4 attenuations(0.0);
			{
				Vector d0;
				Vector p = its.p;
    			camera.sampleDirect(p, pixel_uvs, attenuations, d0);
    			if ( attenuations.isZero() ) continue;
    			assert( (d0 + raydir).norm() < Epsilon );
			}

        	// evaluate the boundary segment
        	Float baseValue = 0.0;
        	Vector n;
        	Float dist = (its.p- camera.cpos.val).norm();
        	{
        		Float dist1 = (_edgeRay.org - camera.cpos.val).norm();
            	Vector e = (v0.val - v1.val).normalized().cross(_edgeRay.dir);
            	Vector e1 = shape.getVertex(edge.v2) -v0.val;
            	Float sinphi = e.norm();
                Vector proj = e.cross(its.geoFrame.n).normalized();
                Float sinphi2 = _edgeRay.dir.cross(proj).norm();
                n = its.geoFrame.n.cross(proj).normalized();
                Float deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));
                if ( sinphi > Epsilon && sinphi2 > Epsilon ) {
                    baseValue = deltaV * (dist/dist1)*(sinphi/sinphi2) * std::abs(its.geoFrame.n.dot(-_edgeRay.dir));
                }
        	}
        	if ( std::abs(baseValue) < Epsilon) continue;
			// Hack: assuming that the camera is fixed
			VectorAD x1 = VectorAD(camera.cpos.val);
			VectorAD u2;
       		int max_interactions = options.max_bounces - depth_boundary;
        	if ( getBoundaryEndpointFromBoundaryRay(scene, its, x1, pEdge, u2) ) {
        		if ( u2.der.isZero(Epsilon) ) continue;
	            BinnedWeight weight(max_interactions+1);
	            radiance(scene, &sampler, its, max_interactions, weight);

                std::vector<Spectrum> L_vals(num_bins, Spectrum(0));
				for (int i = 0; i <= max_interactions; i++) {
					for (auto& [L, path_length]: weight[i]) {
                        Float total_path_length = dist + path_length;
                        Vector2i bin_range = pif->getBinIndexRange(total_path_length);
                        for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                            Float pif_kernel = pif->eval(total_path_length, idx_bin);
                            L_vals[idx_bin] += pif_kernel * L;
                        }
                    }
                }

                std::vector<SpectrumAD> contribBoundarySeg_vals(num_bins, SpectrumAD(0));
                for ( int i = 0; i < num_bins; i++) {
                    for ( int j = 0; j < nder; j++ ) {
                        contribBoundarySeg_vals[i].grad(j) = baseValue * n.dot(u2.grad(j)) * boundary_throughput * L_vals[i];
                    }
                }

                for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
                    int bin_offset = idx_bin * nder * num_pixels;
                    for ( int i = 0; i < 4; i++ ) {
                        if ( attenuations(i) < Epsilon ) continue;
                        int idx_pixel = scene.camera.getPixelIndex(pixel_uvs.col(i));
                        for (int j = 0; j < nder; j++) {
                            image_per_thread[tid][bin_offset + j * num_pixels + idx_pixel] += attenuations(i) * Spectrum(contribBoundarySeg_vals[idx_bin].grad(j)) / edgePdf;
                        }
                    }
                }
        	}
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block + 1)/num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if ( !options.quiet )
            std::cout << std::endl;

    for (int i = 0; i < nworker; ++i) {
        for (int idx_bin = 0; idx_bin < num_bins; ++idx_bin) {
            int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
            int bin_offset2 = idx_bin * nder * num_pixels;
            for (int j = 0; j < nder; ++j) {
                for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel) {
                    int offset1 = (bin_offset1 + (j + 1)*num_pixels + idx_pixel) * 3,
                        offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                    rendered_image[offset1    ] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                }
            }
        }
    }
}

// Top-level entry point: clears `rendered_image` and then adds, in order, the
// interior term, the visibility boundary term (primary, direct and indirect
// edges) and the path-length boundary term. Unless options.quiet is set, a
// heading is printed before each stage and the per-stage wall-clock times are
// printed at the end.
//
//   scene           scene to render
//   options         sample counts per term, quiet flag, antithetic-boundary flag
//   rendered_image  output buffer with num_bins * (nder + 1) * num_pixels * 3 floats
void Binned_PathTracerAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    using Clock = std::chrono::high_resolution_clock;

    const auto &pif = scene.camera.pif;
    const int num_bins = pif->num_bins;
    const int num_pixels = scene.camera.getNumPixels();

    // Start from an all-zero image; every stage below only accumulates.
    const int num_entries = num_bins * (nder + 1) * num_pixels * 3;
    for (int idx = 0; idx < num_entries; ++idx)
        rendered_image[idx] = 0;

    use_antithetic_boundary = options.use_antithetic_boundary;

    // Prints a stage heading unless quiet mode is on.
    const auto announce = [&options](const char *message) {
        if (!options.quiet)
            std::cout << message << std::endl;
    };

    // Seconds elapsed since `begin`, at millisecond resolution.
    const auto secondsSince = [](const Clock::time_point &begin) {
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - begin);
        return elapsed.count() / 1000.0;
    };

    // Interior term
    announce("\nRendering the interior term.");
    const auto interior_begin = Clock::now();
    renderInterior(scene, options, rendered_image);
    const auto t_interior = secondsSince(interior_begin);

    // Visibility boundary term
    announce("\nRendering the visibility boundary term.");
    const auto visibility_begin = Clock::now();
    // Primary edges
    // NOTE(review): the guard queries ptr_edgeManager while renderPrimaryEdges()
    // samples from ptr_psEdgeManager - confirm both describe the same edge set.
    const bool has_primary_edges = options.num_samples_primary_edge > 0
                                && scene.ptr_edgeManager->getNumPrimaryEdges() > 0;
    if (has_primary_edges)
        renderPrimaryEdges(scene, options, rendered_image);
    // Direct secondary edges (only compiled in with boundary NEE)
#ifdef USE_BOUNDARY_NEE
    if (options.num_samples_secondary_edge_direct > 0)
        renderEdgesDirect(scene, options, rendered_image);
#endif
    // Indirect secondary edges
    if (options.num_samples_secondary_edge_indirect > 0)
        renderEdges(scene, options, rendered_image);
    const auto t_visibility = secondsSince(visibility_begin);

    // Path-length boundary term; only evaluated for these two importance functions.
    announce("\nRendering the path-length boundary term.");
    const auto pathlength_begin = Clock::now();
    const bool gated_pif = pif->name == "Boxcar" || pif->name == "TruncatedGaussian";
    if (gated_pif && options.num_samples_time_edge > 0)
        renderBoxTimeGateBoundary(scene, options, rendered_image);
    const auto t_pathlength = secondsSince(pathlength_begin);

    if (options.quiet)
        return;

    // Timing summary.
    std::cout << std::fixed << std::setprecision(2);
    std::cout << "\nElapsed time: \n";
    std::cout << "\tInterior: " << t_interior << "s\n";
    std::cout << "\tVisibility boundary: " << t_visibility << "s\n";
    std::cout << "\tPath-length boundary: " << t_pathlength << "s\n";
    std::cout << std::endl;
}

