#include "binnedTofPathADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include "epc_sampler.h"
#include "math_func.h"
#include <iomanip>
#include <chrono>

// Worker-thread count, initialised once from omp_get_num_procs() during static initialisation.
// In this file it sizes the per-thread accumulation buffers and the OpenMP thread team of
// renderInterior_ZeroAndOneBounce(); the edge renderers declare their own local copy.
static const int nworker = omp_get_num_procs();

// NOTE(review): this macro is not referenced in the code visible in this file section --
// confirm whether it is still consumed elsewhere before relying on or removing it.
#define ALTERNATIVE_EDGE_SAMPLING

// Direct edges

/**
 * Evaluates the boundary contribution of one sampled edge ray whose forward end must land on an emitter.
 *
 * The edge ray is traced in both directions: eRec.its2 is the hit along the ray, eRec.its1 the hit along
 * the flipped ray. If both hits exist, neither lies on one of the two faces adjacent to the edge, and
 * its2 is an emitter, the function fills
 *   - eRec.value0 (derivatives only): the contribution of directly connecting its1 to the sensor, together
 *     with eRec.sensor_vals / eRec.pixel_indices / eRec.idx_pixel and path_length;
 *   - eRec.value1 (derivatives only): the edge factor for longer sensor sub-paths, whose importance samples
 *     are appended to `record` via weightedImportance() (only when max_bounces > 1).
 *
 * @param scene        scene to intersect against and whose camera is sampled
 * @param shape_id     index into scene.shape_list of the shape that owns rEdge
 * @param rEdge        sampled edge; its f0/f1 face ids and v0/v1/v2 vertex ids are read
 * @param edgeRay      ray through the sampled point on the edge
 * @param sampler      random sampler forwarded to the sensor sampling and to weightedImportance()
 * @param max_bounces  the indirect part is evaluated only when this is > 1
 * @param eRec         [out] evaluation record; value0, value1, idx_pixel, sensor_vals and pixel_indices are reset first
 * @param path_length  [out] written only when the direct sensor connection succeeds; otherwise left untouched
 * @param record       [out] (pixel, bin, contribution) tuples appended by weightedImportance()
 * @param quiet        when true, warnings about non-finite gradients are suppressed
 */
void Binned_TofPathTracerAD_PathSpace::evalEdgeDirect(const Scene &scene, int shape_id, const Edge &rEdge, const RayAD &edgeRay, RndSampler *sampler, int max_bounces,
        EdgeEvaluationRecord &eRec, Float &path_length, std::vector<std::tuple<int, int, Spectrum>> &record, bool quiet) const
{
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), evalEdgeDirect() should not be used." << std::endl;
    assert(false);
#endif
    Intersection &its1 = eRec.its1, &its2 = eRec.its2;
    int &idx_pixel = eRec.idx_pixel;

    // Reset every output of the record so that a rejected sample contributes nothing.
    eRec.value0.zero();
    eRec.value1.zero();
    idx_pixel = -1;

    eRec.sensor_vals = Array4(0.0);
    eRec.pixel_indices = Array4i(-1);

    const Shape &shape = *scene.shape_list[shape_id];
    Ray _edgeRay = edgeRay.toRay();
    const Vector &d1 = _edgeRay.dir;
    // its2: hit along the edge ray; its1: hit along the flipped ray. Both must exist.
    if (scene.rayIntersect(_edgeRay, true, its2) && scene.rayIntersect(_edgeRay.flipped(), true, its1)) {
        // Reject hits on the two faces adjacent to the sampled edge, and require the forward hit to be an emitter.
        const Vector2i ind0(shape_id, rEdge.f0), ind1(shape_id, rEdge.f1);
        if (its1.indices != ind0 && its1.indices != ind1 && its2.indices != ind0 && its2.indices != ind1 && its2.isEmitter()) {
            // Cosines of the segment direction against the geometric (gn*) and shading (sn*) normals at both ends.
            const Float gn1d1 = its1.geoFrame.n.dot(d1), sn1d1 = its1.shFrame.n.dot(d1),
                gn2d1 = its2.geoFrame.n.dot(-d1), sn2d1 = its2.shFrame.n.dot(-d1);
            assert(std::abs(its1.wi.z() - sn1d1) < Epsilon && std::abs(its2.wi.z() - sn2d1) < Epsilon);

            // Transmissive BSDFs only need the geometric and shading cosines to agree in sign;
            // otherwise both cosines must be strictly positive.
            bool valid1 = (its1.ptr_bsdf->isTransmissive() && math::signum(gn1d1)*math::signum(sn1d1) > 0.5f) || (!its1.ptr_bsdf->isTransmissive() && gn1d1 > Epsilon && sn1d1 > Epsilon),
                valid2 = (its2.ptr_bsdf->isTransmissive() && math::signum(gn2d1)*math::signum(sn2d1) > 0.5f) || (!its2.ptr_bsdf->isTransmissive() && gn2d1 > Epsilon && sn2d1 > Epsilon);
            if (valid1 && valid2) {
                // const Shape &shape1 = *scene.shape_list[its1.indices[0]]; const Vector3i &f1 = shape1.getIndices(its1.indices[1]);
                // Differentiable vertices of the emitter triangle hit by the forward ray.
                const Shape &shape2 = *scene.shape_list[its2.indices[0]]; const Vector3i &f2 = shape2.getIndices(its2.indices[1]);
                const VectorAD &v0 = shape2.getVertexAD(f2[0]), &v1 = shape2.getVertexAD(f2[1]), &v2 = shape2.getVertexAD(f2[2]);
                Float baseValue = 0.0f;
                Float dist = (its2.p - its1.p).norm();
                Vector n;
                {
                    // Scalar edge factor and the in-plane direction n on the emitter triangle.
                    // NOTE(review): the derivation of this factor is not visible in this file -- confirm
                    // against the boundary-integral formulation before modifying any term.
                    Float cos2 = std::abs(gn2d1);
                    Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized().cross(d1);
                    Float sinphi = e.norm();
                    Vector proj = e.cross(its2.geoFrame.n).normalized();
                    Float sinphi2 = d1.cross(proj).norm();
                    n = its2.geoFrame.n.cross(proj).normalized();

                    // Sign of the contribution, taken from the edge's third vertex (v2) and n relative to e.
                    Float deltaV;
                    Vector e1 = shape.getVertex(rEdge.v2) - shape.getVertex(rEdge.v0);
                    deltaV = math::signum(e.dot(e1))*math::signum(e.dot(n));

                    // Degenerate configurations (either sine near zero) keep baseValue at 0.
                    if (sinphi > Epsilon && sinphi2 > Epsilon)
                        baseValue = deltaV * (its1.t / dist)*(sinphi / sinphi2)*cos2;
                }

                if (std::abs(baseValue) > Epsilon) {
                    VectorAD u2;

                    // Direct
                    {
                        Vector2 pix_uv;
                        Vector d0;
                        Float sensor_val = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uv, d0);
                        if (sensor_val > Epsilon) {
                            // Re-trace the camera ray through the sampled pixel position and accept it
                            // only if it lands on its1 (within ShadowEpsilon).
                            RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
                            IntersectionAD its;
                            if (scene.rayIntersectAD(cameraRay, false, its) && (its.p.val - its1.p).norm() < ShadowEpsilon) {
                                // u2: differentiable intersection with the emitter triangle of the ray from
                                // the camera-side hit through the edge point.
                                bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(its.p, edgeRay.org - its.p), u2);
                                if (valid) {
                                    // Path length: emitter -> its1 segment plus its1 -> camera.
                                    path_length = dist + (its1.p - scene.camera.cpos.val).norm();

                                    Vector d0_local = its1.toLocal(d0);
                                    Spectrum value0 = its1.evalBSDF(d0_local, EBSDFMode::EImportanceWithCorrection) * baseValue * its2.Le(-d1);
                                    // Only derivative slots are written: value times the motion of u2 along n.
                                    for (int j = 0; j < nder; ++j)
                                        eRec.value0.grad(j) = value0 * n.dot(u2.grad(j));

                                    // Second sensor query returning four (pixel, weight) pairs that the
                                    // callers use for splatting.
                                    Matrix2x4 pix_uvs;
                                    eRec.sensor_vals = scene.sampleAttenuatedSensorDirect(its1, sampler, 0, pix_uvs, d0);
                                    for (int i = 0; i < 4; i++) {
                                        eRec.pixel_indices[i] = scene.camera.getPixelIndex(pix_uvs.col(i));
                                    }
                                    idx_pixel = scene.camera.getPixelIndex(pix_uv);
                                }
                            }
                        }
                    }

                    // indirect
                    if (max_bounces > 1) {
                        VectorAD x1, n1;
                        FloatAD J1;
                        scene.getPoint(its1, x1, n1, J1);
                        bool valid = rayIntersectTriangleAD(v0, v1, v2, RayAD(x1, edgeRay.org - x1), u2);
                        if (valid) {
                            // Neither the BSDF nor the emitted radiance is folded in here; renderEdgesDirect()
                            // multiplies by Le and by the importance samples collected below.
                            for (int j = 0; j < nder; ++j)
                                eRec.value1.grad(j) = baseValue * n.dot(u2.grad(j));

                            // Single-entry weight: weightedImportance() in direct mode reads only the
                            // path length (dist); the spectrum slot is a placeholder.
                            BinnedWeight weight(1);
                            weight[0].resize(1);
                            Spectrum tmp_spec(0.f);
                            weight[0][0] = std::tie(tmp_spec, dist);
                            weightedImportance(scene, sampler, its1, max_bounces - 1, true, weight, record);
                        }
                    }
                }
            }
        }

        // Replace non-finite derivatives by zero (with an optional warning).
        // Note this clean-up runs only when both intersections above succeeded.
        for (int j = 0; j < nder; ++j) {
            if (!std::isfinite(eRec.value0.grad(j)[0]) || !std::isfinite(eRec.value0.grad(j)[1]) || !std::isfinite(eRec.value0.grad(j)[2])) {
                if (!quiet) {
                    omp_set_lock(&messageLock);
                    std::cerr << std::fixed << std::setprecision(2)
                        << "\n[WARN] Invalid gradient: [" << eRec.value0.grad(j).transpose() << "]" << std::endl;
                    omp_unset_lock(&messageLock);
                }
                eRec.value0.grad(j).setZero();
            }

            if (!std::isfinite(eRec.value1.grad(j))) {
                if (!quiet) {
                    omp_set_lock(&messageLock);
                    std::cerr << std::fixed << std::setprecision(2)
                        << "\n[WARN] Invalid gradient: [" << eRec.value1.grad(j) << "]" << std::endl;
                    omp_unset_lock(&messageLock);
                }
                eRec.value1.grad(j) = 0.0f;
            }
        }
    }
}

void Binned_TofPathTracerAD_PathSpace::renderEdgesDirect(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifndef USE_BOUNDARY_NEE
    std::cerr << "Without next-event estimation (NEE), renderEdgesDirect() should not be used." << std::endl;
    assert(false);
#endif
    const Camera &camera = scene.camera;
    const auto &pif = camera.pif;
    int num_pixels = camera.getNumPixels();
    const int num_bins = pif->num_bins;
    const int nworker = omp_get_num_procs();
    std::vector<std::vector<Spectrum>> image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

    constexpr int num_samples_per_block = 128;
    long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_direct)*num_pixels;
    const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples) / num_samples_per_block));
    num_samples = num_block * num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (long long index_block = 0; index_block < num_block; ++index_block) {
        for (int omp_i = 0; omp_i < num_samples_per_block; ++omp_i) {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, m_taskId[tid] = index_block * num_samples_per_block + omp_i);
            int shape_id;
            RayAD edgeRay;
            Float edgePdf;
            const Edge &rEdge = scene.sampleEdgeRayDirect(sampler.next3D(), shape_id, edgeRay, edgePdf);
            if (shape_id >= 0) {
                std::vector<std::tuple<int, int, Spectrum>> importance;
                EdgeEvaluationRecord eRec;
                Float path_length;
                evalEdgeDirect(scene, shape_id, rEdge, edgeRay, &sampler, options.max_bounces, eRec, path_length, importance, options.quiet);

                // one-bounce
                if (!eRec.sensor_vals.isZero()) {
                    if (pif->getBinIndex(path_length) > -1) {
                        Vector2i bin_range = pif->getBinIndexRange(path_length);
                        for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                            int bin_offset = idx_bin * nder * num_pixels;
                            Float pif_kernel = pif->eval(path_length, idx_bin);
                            for (int i = 0; i < 4; ++i) {
                                if (eRec.sensor_vals[i] > Epsilon) {
                                    for (int j = 0; j < nder; ++j) {
                                        image_per_thread[tid][bin_offset + j * num_pixels + eRec.pixel_indices[i]] += pif_kernel * eRec.value0.grad(j) * eRec.sensor_vals[i] / edgePdf;
                                    }
                                }
                            }
                        }
                    }
                }

                // multi-bounce
                if (!importance.empty()) {
                    Spectrum light_val = eRec.its2.Le(-edgeRay.dir.val);
                    for (const auto&[idx_pixel, idx_bin, contrib] : importance) {
                        int bin_offset = idx_bin * nder * num_pixels;
                        for (int j = 0; j < nder; ++j) {
                            image_per_thread[tid][bin_offset + j * num_pixels + idx_pixel] += eRec.value1.grad(j)*contrib*light_val / edgePdf;
                        }
                    }
                }
            }
        }

        if (!options.quiet) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block + 1) / num_block);
            omp_unset_lock(&messageLock);
        }
    }
    if (!options.quiet) std::cout << std::endl;

    for (int i = 0; i < nworker; ++i)
        for (int idx_bin = 0; idx_bin < num_bins; ++idx_bin) {
            int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
            int bin_offset2 = idx_bin * nder * num_pixels;
            for (int j = 0; j < nder; ++j)
                for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel) {
                    int offset1 = (bin_offset1 + (j + 1)*num_pixels + idx_pixel) * 3,
                        offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                    rendered_image[offset1    ] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                }
        }
}

// Indirect edges

void Binned_TofPathTracerAD_PathSpace::renderEdges(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
#ifdef USE_BOUNDARY_NEE
    if (options.max_bounces > 1) {
#else
    if (options.max_bounces >= 1) {
#endif
        const Camera &camera = scene.camera;
        const auto &pif = camera.pif;
        int num_pixels = camera.getNumPixels();
        const int num_bins = pif->num_bins;
        const int nworker = omp_get_num_procs();
        std::vector<std::vector<Spectrum> > image_per_thread(nworker);
        for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

        constexpr int num_samples_per_block = 128;
        long long num_samples = static_cast<long long>(options.num_samples_secondary_edge_indirect)*num_pixels;
        const long long num_block = static_cast<long long>(std::ceil(static_cast<Float>(num_samples) / num_samples_per_block));
        num_samples = num_block * num_samples_per_block;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
        for (long long index_block = 0; index_block < num_block; ++index_block) {
            for (int omp_i = 0; omp_i < num_samples_per_block; omp_i++) {
                const int tid = omp_get_thread_num();
                RndSampler sampler(options.seed, m_taskId[tid] = index_block * num_samples_per_block + omp_i);
                // Indirect contribution of edge term.
                int shape_id;
                RayAD edgeRay;
                Float edgePdf;
                const Edge &rEdge = scene.sampleEdgeRay(sampler.next3D(), shape_id, edgeRay, edgePdf);
                if (shape_id >= 0) {
                    EdgeEvaluationRecord eRec;
                    evalEdge(scene, shape_id, rEdge, edgeRay, &sampler, eRec);
                    if (eRec.its1.isValid() && eRec.its2.isValid()) {
                        bool zeroVelocity = eRec.value0.der.isZero(Epsilon) && eRec.value1.der.isZero(Epsilon);
                        if (!zeroVelocity) {
                            traceRayFromEdgeSegement(scene, eRec, edgePdf, options.max_bounces, &sampler, image_per_thread[tid]);
                        }
                    }
                }
            }

            if (!options.quiet) {
                omp_set_lock(&messageLock);
                progressIndicator(Float(index_block + 1) / num_block);
                omp_unset_lock(&messageLock);
            }
        }
        if (!options.quiet)
            std::cout << std::endl;

        for (int i = 0; i < nworker; ++i)
            for (int idx_bin = 0; idx_bin < num_bins; ++idx_bin) {
                int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
                int bin_offset2 = idx_bin * nder * num_pixels;
                for (int j = 0; j < nder; ++j)
                    for (int idx_pixel = 0; idx_pixel < num_pixels; ++idx_pixel) {
                        int offset1 = (bin_offset1 + (j + 1)*num_pixels + idx_pixel) * 3,
                            offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                        rendered_image[offset1] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                        rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                        rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                    }
            }
    }
}

// Helper functions

void Binned_TofPathTracerAD_PathSpace::weightedImportance(const Scene& scene, RndSampler* sampler, const Intersection& its0, int max_depth, bool is_direct,
        const BinnedWeight& weight,
        std::vector<std::tuple<int, int, Spectrum>>& ret) const {
    Intersection its = its0;
    Spectrum throughput(1.0f);
    Vector d0;
    Vector2 pix_uv;
    Ray ray_sensor;
    Float current_path_length = 0;
    
    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    Float pif_pdf;
    int idx_bin;
    Float total_path_length_sampled = pif->sample(sampler, idx_bin, pif_pdf);
    Float pif_value = pif->eval(total_path_length_sampled, idx_bin);
    throughput *= pif_value / pif_pdf;

    for (int d_sensor = 1; d_sensor <= max_depth; d_sensor++) {
        // Ellipsoidal connections.
        if (is_direct) {
            assert(weight.size() == 1);
            Float tmp_path_length = current_path_length + std::get<1>(weight[0][0]);

            if (total_path_length_sampled - tmp_path_length > 0) {
                auto res = ellipsoidalConnectToCamera(scene, sampler, its, idx_bin,
                    total_path_length_sampled - tmp_path_length, throughput);

                for (const auto&[contrib, idx_pixel, idx_bin] : res) {
                    ret.push_back(std::make_tuple(idx_pixel, idx_bin, contrib));
                }
                //ret.insert(ret.end(), res.begin(), res.end());
            }
        }
        else {
            for (int d_emitter = 0; d_emitter <= max_depth - d_sensor; d_emitter++) {
                for (const auto &[L, path_length] : weight[d_emitter]) {
                    Float tmp_path_length = current_path_length + path_length;

                    if (total_path_length_sampled - tmp_path_length > 0) {
                        auto res = ellipsoidalConnectToCamera(scene, sampler, its, idx_bin,
                            total_path_length_sampled - tmp_path_length, L * throughput);

                        for (const auto&[contrib, idx_pixel, idx_bin] : res) {
                            ret.push_back(std::make_tuple(idx_pixel, idx_bin, contrib));
                        }
                        //ret.insert(ret.end(), res.begin(), res.end());
                    }
                }
            }
        }

        // Sample a new direction.
        Vector wo_local, wo;
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, EBSDFMode::EImportanceWithCorrection);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo_local);
        Vector wi = its.toWorld(its.wi);
        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
            break;
        throughput *= bsdf_weight;
        ray_sensor = Ray(its.p, wo);
        scene.rayIntersect(ray_sensor, true, its);
        if (!its.isValid())
            break;

        // Update path length
        current_path_length += its.t;
        if (current_path_length >= total_path_length_sampled)
            break;
    }
}

void Binned_TofPathTracerAD_PathSpace::radiance(const Scene& scene, RndSampler* sampler, const Intersection &its, int max_bounces,
        BinnedWeight &ret) const {
    Intersection _its(its);
    Float current_path_length = 0;

    if (_its.isEmitter()) {
        Spectrum L = _its.Le(_its.toWorld(_its.wi));
        Float path_length = current_path_length;
        ret[0].push_back(std::tie(L, path_length));
    }

    Spectrum throughput(1.0f);
    for (int d_emitter = 1; d_emitter <= max_bounces; d_emitter++) {
        // Direct illumination
        Float pdf_nee;
        Float dist;
        Vector wo;
        Spectrum value = scene.sampleEmitterDirect(_its, sampler->next2D(), sampler, wo, pdf_nee, &dist);
        if (!value.isZero(Epsilon)) {
            Spectrum bsdf_val = _its.evalBSDF(wo);
            Float bsdf_pdf = _its.pdfBSDF(wo);
            Float mis_weight = pdf_nee / (pdf_nee + bsdf_pdf);

            Spectrum L = throughput * value*bsdf_val*mis_weight;
            Float path_length = current_path_length + dist;
            ret[d_emitter].push_back(std::tie(L, path_length));
        }

        // Indirect illumination
        Float bsdf_pdf, bsdf_eta;
        Spectrum bsdf_weight = _its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if (bsdf_weight.isZero(Epsilon)) break;
        wo = _its.toWorld(wo);

        Ray ray_emitter(_its.p, wo);
        if (!scene.rayIntersect(ray_emitter, true, _its)) break;
        throughput *= bsdf_weight;
        current_path_length += (_its.p - ray_emitter.org).norm();

        if (_its.isEmitter()) {
            Spectrum light_contrib = _its.Le(-ray_emitter.dir);
            if (!light_contrib.isZero(Epsilon)) {
                Float dist_sq = (_its.p - ray_emitter.org).squaredNorm();
                Float G = _its.geoFrame.n.dot(-ray_emitter.dir) / dist_sq;
                pdf_nee = scene.pdfEmitterSample(_its) / G;
                Float mis_weight = bsdf_pdf / (pdf_nee + bsdf_pdf);

                Spectrum L = throughput * light_contrib*mis_weight;
                ret[d_emitter].push_back(std::tie(L, current_path_length));
            }
        }
    }
}

void Binned_TofPathTracerAD_PathSpace::traceRayFromEdgeSegement(const Scene &scene, const EdgeEvaluationRecord& eRec, Float edgePdf, int max_depth, RndSampler *sampler, std::vector<Spectrum> &image) const {
    assert(max_depth > 0);
    const auto& camera = scene.camera;
    const auto& pif = camera.pif;
    const int num_bins = pif->num_bins;
    const int num_pixels = image.size() / num_bins / nder;

    // Trace ray towards emitter from its2.
    BinnedWeight weight(max_depth);
    {
        radiance(scene, sampler, eRec.its2, max_depth - 1, weight);
#ifdef USE_BOUNDARY_NEE
        weight[0].clear();
#endif

        Float edge_length = (eRec.its2.p - eRec.its1.p).norm();
        for (int d_emitter = 0; d_emitter < max_depth; d_emitter++) {
            for (auto& w : weight[d_emitter]) {
                std::get<1>(w) += edge_length;
            }
        }
    }

    // Trace ray towards sensor from its1.
    if (!eRec.sensor_vals.isZero()) {
        Float camera_its1_length = (scene.camera.cpos.val - eRec.its1.p).norm();

        // Direct connect to sensor
        for (int d_emitter = 0; d_emitter < max_depth; d_emitter++) {
            for (const auto&[L, path_length] : weight[d_emitter]) {
                Float total_path_length = path_length + camera_its1_length;

                if (pif->getBinIndex(total_path_length) > -1) {
                    Vector2i bin_range = pif->getBinIndexRange(total_path_length);
                    for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
                        int bin_offset = idx_bin * nder * num_pixels;
                        Float pif_kernel = pif->eval(total_path_length, idx_bin);
                        for (int i = 0; i < 4; ++i) {
                            if (eRec.sensor_vals[i] > Epsilon) {
                                for (int j = 0; j < nder; ++j) {
                                    image[bin_offset + j * num_pixels + eRec.pixel_indices[i]] += pif_kernel * L * Spectrum(eRec.value0.grad(j)) * eRec.sensor_vals[i] / edgePdf;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (max_depth > 1) {
        std::vector<std::tuple<int, int, Spectrum>> pathThroughput;
        weightedImportance(scene, sampler, eRec.its1, max_depth - 1, false, weight, pathThroughput);
        for (const auto &[idx_pixel, idx_bin, L] : pathThroughput) {
            int bin_offset = idx_bin * nder * num_pixels;
            for (int j = 0; j < nder; j++) {
                image[bin_offset + j * num_pixels + idx_pixel] += L * eRec.value1.grad(j) / edgePdf;
            }
        }
    }
}

std::vector<std::tuple<SpectrumAD, int>> Binned_TofPathTracerAD_PathSpace::pixelColorAD(const Scene &scene, const RenderOptions &options, RndSampler *sampler, int x, int y) const {
    int pixel_x = x;
    int pixel_y = y;
    std::vector<std::tuple<SpectrumAD, int>> ret;
    Ray rays[2];
    const int max_depth = options.max_bounces;
    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    camera.samplePrimaryRayFromFilter(pixel_x, pixel_y, sampler->next2D(), rays[0], rays[1]);
    VectorAD o = camera.cpos;
    uint64_t state = sampler->state;
    for (int i = 0; i < 2; i++) {
        sampler->state = state;
        Ray &ray = rays[i];

        int idx_bin;
        Float pif_pdf;
        //Float ec_length = camera.samplePIF(sampler, idx_bin, pif_pdf);
        Vector2 ec_length = pif->sampleAntithetic(sampler, idx_bin, pif_pdf);
        uint64_t state_path_length = sampler->state;

        for (int k = 0; k < (use_antithetic_interior ? 2 : 1); k++) {
            if (use_antithetic_interior) sampler->state = state_path_length;
            auto samples = TofIntegratorAD_PathSpace::LiAD(scene, sampler, RayAD(o, ray.dir), ec_length[k], pixel_x, pixel_y, max_depth);
            for (const auto&[tmp_contrib, path_length] : samples) {
                SpectrumAD contrib = 0.5 * tmp_contrib * pif->evalAD(path_length, idx_bin) / pif_pdf;
                if (use_antithetic_interior) contrib *= 0.5;
                ret.push_back(std::make_tuple(contrib, idx_bin));
            }
        }
    }
    return ret;
}

// Interior term

/**
 * Renders the zero- and one-bounce part of the interior term (primal image and derivatives).
 *
 * Work is split into options.num_samples blocks of camera.getNumPixels() samples each. Every sample
 * draws a pair of path lengths with sampleAntithetic(), evaluates LiAD_ZeroAndOneBounce() for the first
 * length (and, with use_antithetic_interior, for the second one from the same sampler state, each
 * weighted by 0.5) and accumulates valid contributions into a per-thread buffer laid out as
 * [bin][value + nder derivatives][pixel]. The buffers are finally averaged over the total sample count
 * and added to rendered_image, which uses the same layout with three floats per entry.
 *
 * Contributions are rejected (with a warning unless options.quiet) when the value is non-finite or
 * negative, or when the largest absolute derivative is non-finite or reaches options.grad_threshold.
 *
 * Sample indices are computed in 64-bit arithmetic because blocks * pixels can exceed the range of int.
 * The function returns without touching rendered_image when there is nothing to sample, which avoids
 * normalising by a zero sample count.
 *
 * @param scene           scene providing the camera and its path-length importance function
 * @param options         seed, num_samples (number of blocks), grad_threshold, quiet flag
 * @param rendered_image  output buffer; contributions are accumulated, not overwritten
 */
void Binned_TofPathTracerAD_PathSpace::renderInterior_ZeroAndOneBounce(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    const long long size_block = camera.getNumPixels();
    const int num_block = options.num_samples;
    // Nothing to sample: bail out instead of dividing 0 by 0 in the normalisation below.
    if (num_block <= 0 || size_block <= 0)
        return;

    const bool cropped = camera.rect.isValid();
    const int num_pixels = cropped ? camera.rect.crop_width * camera.rect.crop_height
        : camera.width * camera.height;
    const int num_bins = pif->num_bins;

    // One private accumulation buffer per worker so the hot loop needs no synchronisation.
    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*(nder+1)*num_pixels, Spectrum(0.0f));

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        // 64-bit sample indices: index_block * size_block overflows int for large renders.
        const long long block_start = index_block * size_block;
        const long long block_end = block_start + size_block;
        for (long long index_sample = block_start; index_sample < block_end; index_sample++) {
            RndSampler sampler(options.seed, index_sample);
            int tid = omp_get_thread_num();

            int idx_bin;
            Float pif_pdf;
            //Float ec_length = camera.samplePIF(&sampler, idx_bin, pif_pdf);
            Vector2 ec_length = pif->sampleAntithetic(&sampler, idx_bin, pif_pdf);

            // Antithetic path lengths replay the same random numbers for the rest of the path.
            uint64_t state = sampler.state;
            for (int k = 0; k < (use_antithetic_interior ? 2 : 1); k++) {
                if (use_antithetic_interior) sampler.state = state;
                auto res = LiAD_ZeroAndOneBounce(scene, &sampler, ec_length[k]);

                for (const auto&[tmp_contrib, path_length, idx_pixel] : res) {
                    if (idx_pixel >= 0) {
                        auto contrib = tmp_contrib * pif->evalAD(path_length, idx_bin) / pif_pdf;
                        if (use_antithetic_interior) contrib *= 0.5;
                        bool val_valid = std::isfinite(contrib.val[0]) && std::isfinite(contrib.val[1]) && std::isfinite(contrib.val[2]) && contrib.val.minCoeff() >= 0.0f;
                        Float tmp_val = contrib.der.abs().maxCoeff();
                        bool der_valid = std::isfinite(tmp_val) && tmp_val < options.grad_threshold;

                        if (val_valid && der_valid) {
                            // Slot 0 of the bin holds the value, slots 1..nder the derivatives.
                            int bin_offset = idx_bin * (nder + 1) * num_pixels;
                            image_per_thread[tid][bin_offset + idx_pixel] += contrib.val;
                            for (int j = 1; j <= nder; j++) {
                                image_per_thread[tid][bin_offset + j * num_pixels + idx_pixel] += contrib.grad(j - 1);
                            }
                        }
                        else {
                            if (!options.quiet) {
                                omp_set_lock(&messageLock);
                                if (!val_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << contrib.val << "]" << std::endl;
                                if (!der_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << contrib.der << "]" << std::endl;
                                omp_unset_lock(&messageLock);
                            }
                        }
                    }
                }
            }
        }

        if (!options.quiet) {
            omp_set_lock(&messageLock);
            // Report the number of finished blocks so the indicator can reach 100%,
            // matching the edge renderers in this file.
            progressIndicator(Float(index_block + 1) / num_block);
            omp_unset_lock(&messageLock);
        }
    }

    // Average over all samples and add to the output image.
    size_t num_samples = static_cast<size_t>(size_block) * num_block;
    for (int i = 0; i < nworker; i++) {
        for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
            int bin_offset = idx_bin * (nder + 1) * num_pixels;
            for (int j = 0; j <= nder; j++) {
                for (int idx_pixel = 0; idx_pixel < num_pixels; idx_pixel++) {
                    int offset = bin_offset + j * num_pixels + idx_pixel;
                    rendered_image[offset * 3] += image_per_thread[i][offset][0] / static_cast<Float>(num_samples);
                    rendered_image[offset * 3 + 1] += image_per_thread[i][offset][1] / static_cast<Float>(num_samples);
                    rendered_image[offset * 3 + 2] += image_per_thread[i][offset][2] / static_cast<Float>(num_samples);
                }
            }
        }
    }
}

// Main

/**
 * Entry point: clears rendered_image and accumulates, in order, the interior term, the visibility
 * boundary term (primary, direct and indirect secondary edges) and the path-length boundary term.
 *
 * Each stage is skipped when its sample count in `options` is zero. The path-length boundary stage
 * additionally requires the importance function to be named "Boxcar" or "TruncatedGaussian".
 * Unless options.quiet is set, a banner is printed before every stage and the elapsed
 * time of each stage afterwards.
 *
 * @param scene           scene to render
 * @param options         sample counts, antithetic-sampling switches, quiet flag
 * @param rendered_image  output buffer of num_bins * (nder + 1) * num_pixels * 3 floats
 */
void Binned_TofPathTracerAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    using namespace std::chrono;

    const auto &pif = scene.camera.pif;
    const int num_bins = pif->num_bins;
    const int num_pixels = scene.camera.getNumPixels();

    // Start from a cleared image; every stage below only accumulates.
    const int num_floats = num_bins * (nder + 1) * num_pixels * 3;
    for (int i = 0; i < num_floats; i++) rendered_image[i] = 0;

    // Cache the options that the helper routines read through member state.
    num_ellipsoidal_connections = options.num_ellipsoidal_connections;
    use_antithetic_boundary = options.use_antithetic_boundary;
    use_antithetic_interior = options.use_antithetic_interior;

    const bool gated_pif = (pif->name == "Boxcar" || pif->name == "TruncatedGaussian");
    if (use_antithetic_boundary) {
        assert(gated_pif);
    }
    if (use_antithetic_interior) {
        assert(pif->name == "Gaussian" || pif->name == "TruncatedGaussian");
    }

    // Prints a stage banner unless running quietly.
    const auto announce = [&options](const char *message) {
        if (!options.quiet)
            std::cout << message << std::endl;
    };
    // Runs a stage and returns its wall-clock duration in seconds (millisecond resolution).
    const auto run_timed = [](auto &&stage) {
        const auto begin = high_resolution_clock::now();
        stage();
        const auto end = high_resolution_clock::now();
        return duration_cast<milliseconds>(end - begin).count() / 1000.0;
    };

    // Interior term
    announce("\nRendering the interior term.");
    const auto t_interior = run_timed([&] {
        if (options.num_samples > 0) {
            renderInterior_ZeroAndOneBounce(scene, options, rendered_image);
            renderInterior(scene, options, rendered_image);
        }
    });

    // Visibility boundary term
    announce("\nRendering the visibility boundary term.");
    const auto t_visibility = run_timed([&] {
        // Primary
        if (options.num_samples_primary_edge > 0 && scene.ptr_edgeManager->getNumPrimaryEdges() > 0)
            Binned_PathTracerAD_PathSpace::renderPrimaryEdges(scene, options, rendered_image);
        // Direct
#ifdef USE_BOUNDARY_NEE
        if (options.num_samples_secondary_edge_direct > 0)
            renderEdgesDirect(scene, options, rendered_image);
#endif
        // Indirect
        if (options.num_samples_secondary_edge_indirect > 0)
            renderEdges(scene, options, rendered_image);
    });

    // Path-length boundary term
    announce("\nRendering the path-length boundary term.");
    const auto t_pathlength = run_timed([&] {
        if (gated_pif && options.num_samples_time_edge > 0)
            renderBoxTimeGateBoundary(scene, options, rendered_image);
    });

    if (!options.quiet) {
        std::cout << std::fixed << std::setprecision(2)
                  << "\nElapsed time: \n"
                  << "\tInterior: " << t_interior << "s\n"
                  << "\tVisibility boundary: " << t_visibility << "s\n"
                  << "\tPath-length boundary: " << t_pathlength << "s\n"
                  << std::endl;
    }
}
