#include "tofIntegratorADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "math_func.h"
#include "bidir_utils.h"
#include "epc_sampler.h"
#include "sphere_sampler.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <fstream>

// Thread count for every OpenMP parallel loop in this file: one worker per
// processor reported by the OpenMP runtime, evaluated once during static
// initialisation of this translation unit.
static const int nworker{omp_get_num_procs()};

// Boundary term due to box time gates

void TofIntegratorAD_PathSpace::renderBoxTimeGateBoundary(const Scene& scene, const RenderOptions& options, ptr<float> rendered_image) const {
    renderBoxTimeGateBoundary_ZeroAndOneBounce(scene, options, rendered_image);
    renderBoxTimeGateBoudnary_MultiBounces(scene, options, rendered_image);
}

void TofIntegratorAD_PathSpace::renderBoxTimeGateBoundary_ZeroAndOneBounce(const Scene& scene, const RenderOptions& options, ptr<float> rendered_image) const {
    using namespace std::chrono;

    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    int size_block = camera.getNumPixels();
    int num_block = options.num_samples_time_edge;
    const bool cropped = camera.rect.isValid();
    const int num_pixels = cropped ? camera.rect.crop_width * camera.rect.crop_height
        : camera.width * camera.height;
    const int num_bins = pif->num_bins;

    std::vector<std::vector<Spectrum> > image_per_thread(nworker);
    for (int i = 0; i < nworker; i++) image_per_thread[i].resize(num_bins*nder*num_pixels, Spectrum(0.0f));

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        int block_start = index_block * size_block;
        int block_end = (index_block + 1)*size_block;
        for (int index_sample = block_start; index_sample < block_end; index_sample++) {
            RndSampler sampler(options.seed, index_sample);
            int tid = omp_get_thread_num();

            int idx_bin;
            Float pif_pdf;
            Vector2 ec_lengths = pif->sampleBoundary(sampler.next1D(), idx_bin, pif_pdf);
            uint64_t state = sampler.state;

            for (int k = 0; k < (use_antithetic_boundary ? 2 : 1); k++) {
                int i;
                if (use_antithetic_boundary) {
                    sampler.state = state;
                    i = k;
                }
                else {
                    i = int(sampler.next1D() * 2);
                }
                auto res = LiAD_ZeroAndOneBounce(scene, &sampler, ec_lengths[i]);
                Float sgn = (i == 0 ? 1.0 : -1.0);
                if (!use_antithetic_boundary) sgn *= 2.0;

                for (const auto&[contrib, path_length, idx_pixel] : res) {
                    if (idx_pixel >= 0) {
                        bool val_valid = std::isfinite(contrib.val[0]) && std::isfinite(contrib.val[1]) && std::isfinite(contrib.val[2]) && contrib.val.minCoeff() >= 0.0f;
                        Float tmp_val = path_length.der.abs().maxCoeff();
                        bool der_valid = std::isfinite(tmp_val) && tmp_val < options.grad_threshold;

                        if (val_valid && der_valid) {
                            int bin_offset = idx_bin * nder * num_pixels;
                            Float pif_kernel = pif->eval(path_length.val, idx_bin);
                            for (int j = 0; j < nder; j++) {
                                image_per_thread[tid][bin_offset + j * num_pixels + idx_pixel] += sgn * pif_kernel * contrib.val * path_length.grad(j) / pif_pdf;
                            }
                        }
                        else {
                            if (!options.quiet) {
                                omp_set_lock(&messageLock);
                                std::cerr << "\n[WARN] In renderBoxTimeGateBoundary_OneBounce!" << std::endl;
                                if (!val_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << contrib.val << "]" << std::endl;
                                if (!der_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << path_length.der << "]" << std::endl;
                                omp_unset_lock(&messageLock);
                            }
                        }
                    }
                }
            }

        }

        if (!options.quiet) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block) / num_block);
            omp_unset_lock(&messageLock);
        }
    }

    std::cout << std::endl;

    size_t num_samples = size_t(size_block) * num_block;
    for (int i = 0; i < nworker; i++) {
        for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
            int bin_offset1 = idx_bin * (nder + 1) * num_pixels;
            int bin_offset2 = idx_bin * nder * num_pixels;
            for (int j = 0; j < nder; j++) {
                for (int idx_pixel = 0; idx_pixel < num_pixels; idx_pixel++) {
                    int offset1 = (bin_offset1 + (j + 1) * num_pixels + idx_pixel) * 3;
                    int offset2 = bin_offset2 + j * num_pixels + idx_pixel;
                    rendered_image[offset1] += image_per_thread[i][offset2][0] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 1] += image_per_thread[i][offset2][1] / static_cast<Float>(num_samples);
                    rendered_image[offset1 + 2] += image_per_thread[i][offset2][2] / static_cast<Float>(num_samples);
                }
            }
        }
    }
}

void TofIntegratorAD_PathSpace::renderBoxTimeGateBoudnary_MultiBounces(const Scene& scene, const RenderOptions& options, ptr<float> rendered_image) const {
    using namespace std::chrono;

    const auto &camera = scene.camera;
    const auto &pif = camera.pif;
    const bool cropped = camera.rect.isValid();
    const int num_pixels = cropped ? camera.rect.crop_width * camera.rect.crop_height
        : camera.width * camera.height;
    const int size_block = 4;
    const int num_bins = pif->num_bins;

    // Pixel sampling
    int num_block = std::ceil((Float)num_pixels / size_block);

    int finished_block = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        int block_start = index_block * size_block;
        int block_end = std::min((index_block + 1)*size_block, num_pixels);

        for (int idx_pixel = block_start; idx_pixel < block_end; idx_pixel++) {
            int ix = cropped ? camera.rect.offset_x + idx_pixel % camera.rect.crop_width
                : idx_pixel % camera.width;
            int iy = cropped ? camera.rect.offset_y + idx_pixel / camera.rect.crop_width
                : idx_pixel / camera.width;
            RndSampler sampler(options.seed, idx_pixel);

            std::vector<SpectrumAD> pixel_vals(num_bins, SpectrumAD(0));
            for (int idx_sample = 0; idx_sample < options.num_samples_time_edge; idx_sample++) {
                int idx_bin;
                Float pif_pdf;
                Vector2 ec_lengths = pif->sampleBoundary(sampler.next1D(), idx_bin, pif_pdf);
                uint64_t state = sampler.state;

                for (int k = 0; k < (use_antithetic_boundary ? 2 : 1); k++) {
                    int i;
                    if (use_antithetic_boundary) {
                        sampler.state = state;
                        i = k;
                    }
                    else {
                        i = int(sampler.next1D() * 2);
                    }

                    Ray rays[2];
                    camera.samplePrimaryRayFromFilter(ix, iy, sampler.next2D(), rays[0], rays[1]);
                    VectorAD o = camera.cpos;
                    RayAD rayAD(o, rays[0].dir);

                    auto res = LiAD(scene, &sampler, rayAD, ec_lengths[i], ix, iy, options.max_bounces);
                    Float sgn = (i == 0 ? 1.0 : -1.0);
                    if (!use_antithetic_boundary) sgn *= 2.0;

                    for (const auto&[contrib, path_length] : res) {
                        bool val_valid = std::isfinite(contrib.val[0]) && std::isfinite(contrib.val[1]) && std::isfinite(contrib.val[2]) && contrib.val.minCoeff() >= 0.0f;
                        Float tmp_val = path_length.der.abs().maxCoeff();
                        bool der_valid = std::isfinite(tmp_val) && tmp_val < options.grad_threshold;

                        if (val_valid && der_valid) {
                            Float pif_kernel = pif->eval(path_length.val, idx_bin);
                            for (int j = 0; j < nder; j++) {
                                pixel_vals[idx_bin].grad(j) += sgn * pif_kernel * contrib.val * path_length.grad(j) / pif_pdf;
                            }
                        }
                        else {
                            if (!options.quiet) {
                                omp_set_lock(&messageLock);
                                std::cerr << "\n[WARN] In renderBoxTimeGateBoundary_OneBounce!" << std::endl;
                                if (!val_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << contrib.val << "]" << std::endl;
                                if (!der_valid)
                                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << path_length.der << "]" << std::endl;
                                omp_unset_lock(&messageLock);
                            }
                        }
                    }
                }
            }

            for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
                SpectrumAD pixel_val = pixel_vals[idx_bin] / options.num_samples_time_edge;
                int bin_offset = idx_bin * (nder + 1) * num_pixels * 3;

                for (int ch = 1; ch <= nder; ++ch) {
                    int offset = (ch*num_pixels + idx_pixel) * 3;
                    rendered_image[bin_offset + offset] += static_cast<float>((pixel_val.grad(ch - 1))(0));
                    rendered_image[bin_offset + offset + 1] += static_cast<float>((pixel_val.grad(ch - 1))(1));
                    rendered_image[bin_offset + offset + 2] += static_cast<float>((pixel_val.grad(ch - 1))(2));
                }
            }
        }

        if (!options.quiet) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block) / num_block);
            omp_unset_lock(&messageLock);
        }
    }

    std::cout << std::endl;
}

// Helper functions

std::vector<std::tuple<Spectrum, int, int>> TofIntegratorAD_PathSpace::ellipsoidalConnectToCamera(const Scene& scene, RndSampler* sampler,
        const Intersection& x1_its, int idx_bin, Float ec_length, const Spectrum& throughput) const {
    std::vector<std::tuple<Spectrum, int, int>> ret;
    if (ec_length <= 0) return ret;

    const auto& camera = scene.camera;
    EPCSampler epc_sampler(x1_its.p, x1_its.geoFrame.n, camera.cpos.val, camera.cframe.n.val, ec_length, scene);

    for (int i = 0; i < num_ellipsoidal_connections; i++) {
        Intersection x_its;
        Float x_jacobian, x_pdf;
        if (epc_sampler.sample(sampler->next2D(), x_its, x_jacobian, x_pdf)) {
            Matrix2x4 pix_uvs;
            Vector vec_x_camera = (camera.cpos.val - x_its.p).normalized();

            Array4 sensor_values(0.0);
            camera.sampleDirect(x_its.p, pix_uvs, sensor_values, vec_x_camera);
            if (!sensor_values.isZero()) {
                Vector vec_x_cur = x1_its.p - x_its.p;
                Float dist_x_cur = vec_x_cur.norm();
                vec_x_cur /= dist_x_cur;
                Vector cur_wo = x1_its.toLocal(-vec_x_cur);

                Vector x_wo = x_its.toLocal(vec_x_camera);
                Float G_x_cur = fabs(x_its.geoFrame.n.dot(vec_x_cur)) / square(dist_x_cur);

                Spectrum res = throughput * x1_its.evalBSDF(cur_wo, EBSDFMode::EImportanceWithCorrection) * G_x_cur *
                    x_its.evalBSDF(x_wo, EBSDFMode::EImportanceWithCorrection) * x_jacobian / (x_pdf * num_ellipsoidal_connections);

                for (int k = 0; k < 4; k++) {
                    if (sensor_values[k] > Epsilon) {
                        int idx_pixel = camera.getPixelIndex(pix_uvs.col(k));
                        if (idx_pixel >= 0)
                            ret.push_back(std::make_tuple(res * sensor_values[k], idx_pixel, idx_bin));
                    }
                }
            }
        }
    }
    return ret;
}

std::vector<std::tuple<SpectrumAD, FloatAD, int>> TofIntegratorAD_PathSpace::ellipsoidalConnectToLightAD(const Scene& scene, RndSampler* sampler,
        const IntersectionAD& x1_its_AD, Float ec_length, bool x1_is_camera, Float throughput) const {
    std::vector<std::tuple<SpectrumAD, FloatAD, int>> ret;
    if (ec_length <= 0) return ret;

    VectorAD x1_AD = x1_its_AD.p;
    VectorAD n1_AD = x1_its_AD.geoFrame.n;

    Intersection x2_its;
    Float x2_pdf;
    scene.sampleEmitterPosition(sampler->next2D(), x2_its, &x2_pdf);
    Vector x2 = x2_its.p;

    EPCSampler epc_sampler(x1_AD.val, n1_AD.val, x2, x2_its.geoFrame.n, ec_length, scene);

    for (int i = 0; i < num_ellipsoidal_connections; i++) {
        Intersection x_its;
        Float x_jacobian, x_pdf;
        if (!epc_sampler.sample(sampler->next2D(), x_its, x_jacobian, x_pdf)) continue;

        FloatAD J = x_jacobian;

        // Set velocity for x.
        IntersectionAD x_its_AD;
        FloatAD tmpJ;
        scene.getPoint(x_its, x1_AD, x_its_AD, tmpJ);
        J *= tmpJ;
        VectorAD x_AD = x_its_AD.p;

        // Set velocity for x2.
        VectorAD x2_AD, n2_AD;
        scene.getPoint(x2_its, x2_AD, n2_AD, tmpJ);
        J *= tmpJ;

        // Compute contribution (x <-> x2).
        VectorAD vec_x_x2_AD = x2_AD - x_AD;
        FloatAD distSqr = vec_x_x2_AD.squaredNorm();
        vec_x_x2_AD /= distSqr.sqrt();
        FloatAD G_x_x2_AD = n2_AD.dot(-vec_x_x2_AD) / distSqr;
        VectorAD x_wo_AD = x_its_AD.toLocal(vec_x_x2_AD);
        SpectrumAD light_contrib = x2_its.ptr_emitter->evalAD(n2_AD, -vec_x_x2_AD);
        SpectrumAD x_x2_contrib = light_contrib * G_x_x2_AD * x_its_AD.evalBSDF(x_wo_AD);

        FloatAD path_length = (x1_AD - x_AD).norm() + (x_AD - x2_AD).norm();

        // Compute contribution (x1 <-> x).
        if (x1_is_camera) {
            Matrix2x4AD pixel_uvs;
            Vector4AD weights;
            VectorAD vec_x_x1;
            scene.camera.sampleDirectAD(x_AD, pixel_uvs, weights, vec_x_x1);

            Float pdf = x2_pdf * x_pdf;
            SpectrumAD res = x_x2_contrib * throughput * J / (pdf * num_ellipsoidal_connections);

            for (int i = 0; i < 4; i++) {
                Vector2 pixel_uv(pixel_uvs(0, i).val, pixel_uvs(1, i).val);
                int idx_pixel = scene.camera.getPixelIndex(pixel_uv);
                SpectrumAD x1_x_contrib = SpectrumAD(weights(i)) * x_its_AD.wi.z();

                ret.push_back(std::make_tuple(res * x1_x_contrib, path_length, idx_pixel));
            }
        }
        else {
            int idx_pixel = -1;
            VectorAD vec_x1_x_AD = x_AD - x1_AD;
            FloatAD distSqr = vec_x1_x_AD.squaredNorm();
            vec_x1_x_AD /= distSqr.sqrt();
            FloatAD G_x1_x_AD = x_its_AD.wi.z() / distSqr;
            VectorAD x1_xo_AD = x1_its_AD.toLocal(vec_x1_x_AD);
            SpectrumAD x1_x_contrib = x1_its_AD.evalBSDF(x1_xo_AD) * G_x1_x_AD;

            Float pdf = x2_pdf * x_pdf;
            SpectrumAD res = x_x2_contrib * x1_x_contrib * throughput * J / (pdf * num_ellipsoidal_connections);
            ret.push_back(std::make_tuple(res, path_length, idx_pixel));
        }
    }
    return ret;
}

// Spherical (zero-bounce) connection from the camera to an emitter, with derivatives.
//
// A SphereSampler built from x1 (the camera), its normal and sph_radius
// proposes up to num_ellipsoidal_connections points x. Judging by the name,
// these presumably lie at distance sph_radius from x1; TODO confirm in
// sphere_sampler.h. For each accepted emitter point x, the emitted radiance
// toward the camera, weighted by the cosine at x and the sampling Jacobian/pdf,
// is returned with AD path length |x - x1|. There is one record per sensor tap,
// tagged with getPixelIndex's result (callers filter indices < 0).
// Only x1_is_camera == true is supported; that is asserted in debug builds.
// Returns nothing when sph_radius <= 0.
std::vector<std::tuple<SpectrumAD, FloatAD, int>> TofIntegratorAD_PathSpace::sphericalConnectToLightAD(const Scene& scene, RndSampler* sampler,
        const IntersectionAD& x1_its_AD, Float sph_radius, bool x1_is_camera, Float throughput) const {
    // Only handle the case with 0 indirect bounce.
    assert(x1_is_camera);

    std::vector<std::tuple<SpectrumAD, FloatAD, int>> ret;
    if (sph_radius <= 0) return ret;

    VectorAD x1_AD = x1_its_AD.p;
    VectorAD n1_AD = x1_its_AD.geoFrame.n;

    SphereSampler sphere_sampler(x1_AD.val, n1_AD.val, sph_radius, scene);

    for (int i = 0; i < num_ellipsoidal_connections; i++) {
        Intersection x_its;
        Float x_jacobian, x_pdf;
        if (!sphere_sampler.sample(sampler->next2D(), x_its, x_jacobian, x_pdf)) continue;

        FloatAD J = x_jacobian;

        // Set velocity for x.
        IntersectionAD x_its_AD;
        FloatAD tmpJ;
        scene.getPoint(x_its, x1_AD, x_its_AD, tmpJ);
        J *= tmpJ;

        // Only points on an emitter contribute to the zero-bounce term. This file
        // cannot show whether SphereSampler restricts itself to emitters, so guard
        // the ptr_emitter dereference below instead of crashing on a null pointer.
        if (x_its_AD.ptr_emitter == nullptr) continue;

        VectorAD x_AD = x_its_AD.p;

        FloatAD path_length = (x_AD - x1_AD).norm();

        // Compute contribution (x1 <-> x).
        if (x1_is_camera) {
            Matrix2x4AD pixel_uvs;
            Vector4AD weights;
            VectorAD vec_x_x1_AD;
            scene.camera.sampleDirectAD(x_AD, pixel_uvs, weights, vec_x_x1_AD);

            SpectrumAD x1_x_contrib = x_its_AD.ptr_emitter->evalAD(x_its_AD.shFrame.n, vec_x_x1_AD) * x_its_AD.wi.z();
            SpectrumAD res = x1_x_contrib * throughput * J / (x_pdf * num_ellipsoidal_connections);

            // One record per sensor tap; `tap` avoids shadowing the outer loop index.
            for (int tap = 0; tap < 4; tap++) {
                Vector2 pixel_uv(pixel_uvs(0, tap).val, pixel_uvs(1, tap).val);
                int idx_pixel = scene.camera.getPixelIndex(pixel_uv);
                ret.push_back(std::make_tuple(res * SpectrumAD(weights(tap)), path_length, idx_pixel));
            }
        }
    }
    return ret;
}

std::vector<std::tuple<SpectrumAD, FloatAD>> TofIntegratorAD_PathSpace::LiAD(const Scene &scene, RndSampler* sampler, const RayAD &_ray,
        Float ec_length, int pixel_x, int pixel_y, int max_depth) const {
    Ray ray = _ray.toRay();
    Intersection its;
    IntersectionAD itsAD, pre_itsAD;
    std::vector<std::tuple<SpectrumAD, FloatAD>> records;
    SpectrumAD throughputAD(Spectrum::Ones());
    Float eta = 1.0f;
    FloatAD path_length = 0;
    Float bsdf_pdf;
    const auto& camera = scene.camera;

    pre_itsAD.p = _ray.org;

    scene.rayIntersect(ray, false, its);

    for (int depth = 0; depth <= max_depth; depth++) {
        bool isFirst = (depth == 0);

        if (!its.isValid())
            break;

        // getPoint
        FloatAD J;
        scene.getPoint(its, pre_itsAD.p, itsAD, J);
        VectorAD x = itsAD.p;
        VectorAD dir = x - pre_itsAD.p;
        FloatAD dist = dir.norm();
        dir /= dist;

        path_length += dist;

        // calculate throughput
        VectorAD dir_local = pre_itsAD.toLocal(dir);
        SpectrumAD f = isFirst ? camera.evalFilterAD(pixel_x, pixel_y, itsAD) : pre_itsAD.evalBSDF(dir_local) / bsdf_pdf;
        FloatAD G = isFirst ? FloatAD(1.) : itsAD.geoFrame.n.dot(-dir) / dist.square();
        Float pdf = G.val;
        throughputAD *= f * G * J / pdf;

        if (depth + 1 >= max_depth || path_length.val >= ec_length)
            break;

        // Light ellipsoidal connection: 2+ bounces
        {
            auto res = ellipsoidalConnectToLightAD(scene, sampler, itsAD,
                ec_length - path_length.val, false, 1.0);
            for (auto&[tmp_contrib, ec_length, idx_pixel] : res) {
                if (!tmp_contrib.isZero(Epsilon)) {
                    assert(idx_pixel == -1);
                    FloatAD total_path_length = path_length + ec_length;
                    SpectrumAD contrib = throughputAD * tmp_contrib;
                    records.push_back(std::tie(contrib, total_path_length));
                }
            }
        }

        // BSDF sampling
        Vector wo;
        Float bsdf_eta;
        auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo);
        ray = Ray(its.p, wo);
        eta *= bsdf_eta;

        scene.rayIntersect(ray, true, its);

        pre_itsAD = itsAD;
    }
    return records;
}

std::vector<std::tuple<SpectrumAD, FloatAD, int>> TofIntegratorAD_PathSpace::LiAD_ZeroAndOneBounce(const Scene& scene, RndSampler* sampler,
        Float ec_length) const {
    const auto& camera = scene.camera;

    IntersectionAD x1_its_AD;
    x1_its_AD.p = camera.cpos;
    x1_its_AD.geoFrame.n = camera.cframe.n;
    auto res = ellipsoidalConnectToLightAD(scene, sampler, x1_its_AD,
        ec_length, true, 1.0);

    auto res0 = sphericalConnectToLightAD(scene, sampler, x1_its_AD,
        ec_length, true, 1.0);
    res.insert(res.end(), res0.begin(), res0.end());

    return res;
}