#include "boundary.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/logger.h>
namespace
{
    // Debug image (per-thread buffers) written by handleBoundaryInteraction and
    // (re)initialized at the start of PrimaryEdgeIntegrator::renderD /
    // forwardRenderD. It is a single file-level object, so those entry points
    // are not reentrant: two concurrent renders in one process would share it.
    DebugInfo debugInfo;

    /*
     * Primal function differentiated by Enzyme in d_velocity().
     *
     * Gathers the three vertices referenced by the sampled edge
     * (edge.v0/v1/v2 of shape bRec.shape_id) and the three vertices of the
     * triangle bRec.tri_id_S of shape bRec.shape_id_S, and writes
     * normal_velocity(camera position, edge vertices, bRec.t, bRec.dir,
     * triangle vertices) into `res`.
     */
    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        res = normal_velocity(scene.camera.cpos,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xS_0, xS_1, xS_2);
    }

    /*
     * Reverse-mode derivative of velocity(): back-propagates the scalar
     * adjoint `d_u` of velocity()'s output into `d_scene` via Enzyme.
     * `scene`/`d_scene` are a duplicated (primal/shadow) pair; the sampling
     * record is passed as a constant (not differentiated); the primal output
     * `u` is not needed (enzyme_dupnoneed).
     *
     * The record type is exactly velocity()'s parameter type because Enzyme
     * forwards &bRec to velocity() through a type-erased call.
     * Without ENZYME && ENZYME_BOUNDARY_PRIMARY this function is a no-op.
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    const BoundarySamplingRecord &bRec,
                    Float d_u)
    {
        [[maybe_unused]] Float u = 0;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_PRIMARY)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &bRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    /*
     * Unidirectional path tracer estimating radiance arriving along `_ray`.
     *
     * Emission at the first hit is added unweighted. At each vertex, emitter
     * sampling (next-event estimation) and BSDF sampling are combined with
     * the power heuristic (squared pdfs). Emitters hit by a BSDF-sampled ray
     * receive the matching MIS weight. The path stops after `max_depth`
     * bounces, on a zero BSDF sample, or when the continuation ray escapes.
     */
    [[maybe_unused]] Spectrum Li(const Scene &scene, RndSampler *sampler, const Ray &_ray, int max_depth)
    {
        Ray ray(_ray);
        Intersection its;
        Spectrum ret = Spectrum::Zero();
        scene.rayIntersect(ray, true, its);
        if (its.isValid())
        {
            Spectrum throughput = Spectrum::Ones();
            int depth = 0;
            while (depth <= max_depth && its.isValid())
            {
                // Only the first hit's emission is added directly; later
                // emitter hits are MIS-weighted after BSDF sampling below.
                if (its.isEmitter() && depth == 0)
                    ret += throughput * its.Le(-ray.dir);
                if (depth >= max_depth)
                    break;
                // Direct illumination (emitter sampling)
                Float pdf_nee;
                Vector wo;
                auto value = scene.sampleEmitterDirect(its, sampler->next2D(), sampler, wo, pdf_nee);
                if (!value.isZero())
                {
                    auto bsdf_val = its.evalBSDF(wo);
                    Float bsdf_pdf = its.pdfBSDF(wo);
                    auto mis_weight = square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
                    ret += throughput * value * bsdf_val * mis_weight;
                }
                // Indirect illumination (BSDF sampling)
                Float bsdf_pdf, bsdf_eta;
                auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
                if (bsdf_weight.isZero())
                    break;

                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                if (!scene.rayIntersect(ray, true, its))
                    break;

                throughput *= bsdf_weight;
                if (its.isEmitter())
                {
                    Spectrum light_contrib = its.Le(-ray.dir);
                    if (!light_contrib.isZero())
                    {
                        // Convert the emitter's pdf to the solid-angle measure
                        // so it is comparable with bsdf_pdf.
                        auto dist_sq = (its.p - ray.org).squaredNorm();
                        auto geometry_term = its.geoFrame.n.dot(-ray.dir) / dist_sq;
                        pdf_nee = scene.pdfEmitterSample(its) / geometry_term;
                        auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                        ret += throughput * light_contrib * mis_weight;
                    }
                }

                depth++;
            }
        }
        return ret;
    }

    /*
     * Splats one primary-boundary sample into the adjoint computation.
     *
     * The sample is rejected when:
     * - `p` cannot be projected onto the camera;
     * - `p` is not visible from the camera position;
     * - the sensor response is below Epsilon.
     *
     * Otherwise the scalar adjoint
     *     d_u = sum(d_image[pixel] * weight * sensor_val)
     * is back-propagated through the boundary velocity (d_velocity).
     *
     * When the file-external `forward` flag is set, `param` of the shape
     * reported by scene.getShapeRequiresGrad() is zeroed in `d_scene` before
     * d_velocity and read back afterwards. The value read is this sample's
     * contribution, and it is added to the red channel of the calling thread's
     * debug image.
     */
    void handleBoundaryInteraction(const Vector &p,
                                   const Scene &scene, Scene &d_scene,
                                   BoundarySamplingRecord &eRec,
                                   RndSampler &sampler, const Spectrum &weight,
                                   [[maybe_unused]] int max_depth, std::vector<Spectrum> &d_image)
    {
        int shape_idx = -1;
        if (forward)
            shape_idx = scene.getShapeRequiresGrad();
        // Only probe with a valid shape index: indexing shape_list with -1
        // (or past its end) would be undefined behavior.
        const bool probe = forward && shape_idx >= 0 &&
                           static_cast<size_t>(shape_idx) < d_scene.shape_list.size();

        CameraDirectSamplingRecord cRec;
        if (!scene.camera.sampleDirect(p, cRec))
            return;
        if (!scene.isVisible(p, true, scene.camera.cpos, true))
            return;
        auto [pixel_idx, sensor_val] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensor_val < Epsilon)
            return;
        Float d_u = (d_image[pixel_idx] * weight * sensor_val).sum();

        // Zeroing first means the value read back afterwards is exactly the
        // gradient produced by this one sample.
        if (probe)
            d_scene.shape_list[shape_idx]->param = 0;
        d_velocity(scene, d_scene, eRec, d_u);
        if (probe)
        {
            const Float param = d_scene.shape_list[shape_idx]->param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
        }
    }

    /*
     * Draws one primary-boundary sample and back-propagates its contribution.
     *
     * Steps:
     * - Pick a point xB on an edge from `edge_dist`/`edge_indices`.
     * - Shoot a ray from xB directly away from the camera position xD to
     *   find the surface point xS behind the edge.
     * - Reject the sample unless:
     *   - the edge's two adjacent faces lie on opposite sides of the ray
     *     (silhouette test, only for edges with two faces);
     *   - the hit at xS passes the front-facing test;
     *   - xB is visible from xD.
     * - Weight the radiance Li seen through xB by the change-of-variables
     *   Jacobian dlS_dlB and |cos| at xS, divided by the edge-sample pdf.
     * - Hand the result to handleBoundaryInteraction at xS.
     */
    void d_samplePrimaryBoundary(const Scene &scene, Scene &d_scene,
                                 RndSampler &sampler, const RenderOptions &options,
                                 const DiscreteDistribution &edge_dist,
                                 const std::vector<Vector2i> &edge_indices,
                                 std::vector<Spectrum> &d_image)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(sampler.next1D(),
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        const Vector v0 = shape->getVertex(edge.v0);
        const Vector v1 = shape->getVertex(edge.v1);
        // xB: sampled point on the edge; xD: camera position.
        const Vector xB = v0 + (v1 - v0) * eRec.t;
        const Vector &xD = scene.camera.cpos;
        Ray ray(xB, (xB - xD).normalized());
        Intersection its;

        if (!scene.rayIntersect(ray, true, its))
            return;
        const Vector &xS = its.p;
        // populate the data in BoundarySamplingRecord eRec
        eRec.shape_id_S = its.indices[0];
        eRec.tri_id_S = its.indices[1];
        // sanity check
        {
            // make sure the ray is tangent to the surface: the two faces
            // adjacent to the edge must face opposite ways w.r.t. the ray
            if (edge.f0 >= 0 && edge.f1 >= 0)
            {
                Vector n0 = shape->getGeoNormal(edge.f0),
                       n1 = shape->getGeoNormal(edge.f1);
                Float dotn0 = ray.dir.dot(n0),
                      dotn1 = ray.dir.dot(n1);
                if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                    return;
            }
            // NOTE prevent intersection with a backface

            Float gnDotD = its.geoFrame.n.dot(-ray.dir);
            Float snDotD = its.shFrame.n.dot(-ray.dir);
            bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                           (!its.ptr_bsdf->isTransmissive() && gnDotD > Epsilon && snDotD > Epsilon);
            if (!success)
                return;
            if (!scene.isVisible(xB, true, xD, true))
                return;
        }

        /* Jacobian determinant that accounts for the change of variable */
        Float J = dlS_dlB(xD,
                          xB, (v0 - v1).normalized(),
                          xS, its.geoFrame.n);

        Float cosS = std::abs(its.geoFrame.n.dot(-ray.dir));
        Float baseValue = J * cosS;
        assert(baseValue > -Epsilon);
        /* Sample source path */
        Spectrum value = Li(scene, &sampler,
                            Ray(xB, ray.dir),
                            options.max_bounces);
        /* Evaluate primary boundary segment */
        handleBoundaryInteraction(xS, scene, d_scene,
                                  eRec, sampler, value * baseValue / eRec.pdf,
                                  options.max_bounces, d_image);
    }

}

PrimaryEdgeIntegrator::PrimaryEdgeIntegrator(const Scene &scene)
{
    // Build the boundary/silhouette edge sampling tables for `scene` right away.
    this->configure(scene);
}

PrimaryEdgeIntegrator::PrimaryEdgeIntegrator(const Properties & /*props*/)
{
    // No properties are read here, and configure() is not called:
    // the edge sampling tables must be built by a later configure(scene).
}

void PrimaryEdgeIntegrator::configure(const Scene &scene)
{
    edge_indices.clear();
    edge_dist.clear();
    /* generate the edge distribution */
    auto &camera = scene.camera;
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &shape = *scene.shape_list[i];
        if (scene.bsdf_list[shape.bsdf_id]->isNull() &&
            shape.light_id < 0)
            continue;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];
            // if an edge is a boundary edge or a sihoulette edge
            if (edge.mode == 0 ||
                shape.isSihoulette(edge, camera.cpos))
            {
                edge_indices.push_back({i, j});
                edge_dist.append(edge.length);
                continue;
            }
        }
    }
    edge_dist.normalize();
}

ArrayXd PrimaryEdgeIntegrator::renderD(SceneAD &sceneAD,
                                       RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("PrimaryEdgeIntegrator::renderD with spp = {}",
              options.num_samples_primary_edge);
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_primary_edge;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        return debugInfo.getArray();

    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());

    Timer _("Primary boundary");

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        for (int j = 0; j < nblock_samples; j++)
        {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, i * nblock_samples + j);
            // sample a point on the boundary
            d_samplePrimaryBoundary(scene, gm.get(tid),
                                    sampler, options,
                                    edge_dist, edge_indices,
                                    _d_image_spec_list);
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scene
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return flattened(debugInfo.getArray());
}

ArrayXd PrimaryEdgeIntegrator::forwardRenderD(SceneAD &sceneAD,
                                       RenderOptions &options) const
{
    PSDR_INFO("PrimaryEdgeIntegrator::forwardRenderD with spp = {}",
              options.num_samples_primary_edge);
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_primary_edge;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        return debugInfo.getArray();

    std::vector<Spectrum> _d_image_spec_list(camera.getNumPixels(), Spectrum(1, 0, 0) / nblock_samples / nblocks);
    
    Timer _("Primary boundary");

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        for (int j = 0; j < nblock_samples; j++)
        {
            const int tid = omp_get_thread_num();
            RndSampler sampler(options.seed, i * nblock_samples + j);
            // sample a point on the boundary
            d_samplePrimaryBoundary(scene, gm.get(tid),
                                    sampler, options,
                                    edge_dist, edge_indices,
                                    _d_image_spec_list);
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scene
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return flattened(debugInfo.getArray());
}
