#include "vol_boundary.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/logger.h>

namespace
{
    // File-local debug accumulator. renderD() re-creates it at the start of every call
    // and returns the contents of its getArray(); handleBoundaryInteraction() adds to
    // its per-thread images only when FORWARD is defined.
    static DebugInfo debugInfo;
    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        // getPoint xD
        Vector xD = scene.camera.cpos;

        if (bRec.onSurface_S)
        {
            const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
            const auto &indS = shapeS->getIndices(bRec.tri_id_S);
            const Vector &xS_0 = shapeS->getVertex(indS[0]);
            const Vector &xS_1 = shapeS->getVertex(indS[1]);
            const Vector &xS_2 = shapeS->getVertex(indS[2]);
            res = normal_velocity(xD,
                                  xB_0, xB_1, xB_2, bRec.t,
                                  xS_0, xS_1, xS_2);
        }
        else
        {
            assert(bRec.med_id_S != -1);
            assert(bRec.tet_id_S != -1);
            assert(abs(1.0 - bRec.barycentric4_S.sum()) < 1e-5);
            const Medium *medium = scene.medium_list[bRec.med_id_S];
            const Vector4i &indices = medium->getTet(bRec.tet_id_S);
            const Vector &xS_0 = medium->getVertex(scene, indices[0]);
            const Vector &xS_1 = medium->getVertex(scene, indices[1]);
            const Vector &xS_2 = medium->getVertex(scene, indices[2]);
            const Vector &xS_3 = medium->getVertex(scene, indices[3]);
            // getPoint
            Vector xS = Vector::Zero();
            for (int i = 0; i < 4; i++)
                xS += bRec.barycentric4_S[i] *
                      medium->getVertex(scene, indices[i]);
            Float dist = (xS - xD).norm();
            res = normal_velocity(xD,
                                  xB_0, xB_1, xB_2, bRec.t,
                                  xS_0, xS_1, xS_2, xS_3, detach(dist));
        }
    }

    /**
     * Differentiate velocity() through Enzyme, seeding its output with d_u.
     *
     * The call is compiled only when both ENZYME and ENZYME_BOUNDARY_PRIMARY are
     * defined; otherwise this function does nothing.
     *
     * @param scene   primal scene
     * @param d_scene shadow scene passed to Enzyme alongside scene
     * @param eRec    boundary sample, passed to Enzyme as a constant
     * @param d_u     seed passed as the shadow of velocity()'s output
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    const BoundarySamplingRecord &eRec,
                    Float d_u)
    {
        // Storage for velocity()'s output argument; its value is never read here.
        [[maybe_unused]] Float primalOut;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_PRIMARY)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &primalOut, &d_u);
#endif
    }

    /**
     * Estimate the radiance arriving along _ray with a path-tracing loop that handles
     * both surface interactions and interactions inside a medium.
     *
     * Each iteration either scatters inside the current medium (when
     * medium->sampleDistance() returns true) or interacts with the surface stored in
     * `its`. Both cases combine emitter sampling with phase-function / BSDF sampling,
     * weighting the two estimates with miWeight().
     *
     * @param scene       scene to trace against
     * @param _ray        initial ray (copied; the copy is updated at every bounce)
     * @param OnSurface   forwarded to the first scene.rayIntersect() call
     * @param sampler     random number source
     * @param medium      medium the ray starts in, or nullptr for none
     * @param max_bounces the loop runs while depth <= max_bounces; no new scattering
     *                    event is sampled once depth >= max_bounces
     * @return accumulated radiance estimate
     */
    Spectrum radiance(const Scene &scene, const Ray &_ray, bool OnSurface, RndSampler *sampler, const Medium *medium, int max_bounces)
    {
        // loop variables : ray, its, mRec, ret, depth, throughput, incEmission, medium
        Ray ray(_ray);
        Intersection its;
        MediumSamplingRecord mRec;
        Spectrum ret = Spectrum::Zero();
        int depth = 0;
        Spectrum throughput = Spectrum::Ones();
        // True while directly hit emitters still have to be added; cleared after any
        // bounce whose emitter contribution was already counted through MIS.
        bool incEmission = true;
        // NOTE(review): med_id is only written (never read) in this function, and
        // shape_id / eta are never used.
        int med_id = -1; // FIXME: assume the camera is outside any shapes
        int shape_id = -1;
        Float eta = 1.0f;
        scene.rayIntersect(ray, OnSurface, its);
        while (depth <= max_bounces)
        {
            // Try to sample a scattering event in the medium before the surface hit at its.t.
            bool inside_med = medium != nullptr &&
                              medium->sampleDistance(Ray(ray), its.t, sampler, mRec);
            if (inside_med) // sampled a medium interaction
            {
                if (depth >= max_bounces)
                    break;
                const PhaseFunction *phase = scene.phase_list[medium->phase_id];
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(mRec.p);
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, mRec.medium);
                if (!value.isZero())
                {
                    Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                    if (phaseVal != 0)
                    {
                        Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                        ret += throughput * value * phaseVal * mis_weight;
                    }
                }

                // ====================== phase sampling =============================
                Vector wo;
                Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
                Float phasePdf = phase->pdf(-ray.dir, wo);
                if (phaseVal == 0)
                    break;
                throughput *= phaseVal;
                ray = Ray(mRec.p, wo);

                // Trace the scattered ray; this also refreshes `its` for the next iteration.
                value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, mRec.medium, its, dRec);
                if (!value.isZero())
                {
                    Float pdf_emitter = scene.pdfEmitterDirect(dRec);
                    Float mis_weight = miWeight(phasePdf, pdf_emitter / dRec.G);
                    ret += throughput * value * mis_weight;
                }
                // update loop variables
                incEmission = false;
            }
            else // sampled a surface interaction
            {
                // The ray crossed the medium without scattering.
                if (medium)
                {
                    throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                if (!its.isValid())
                    break;

                // Add emission of a directly hit emitter unless it was already
                // accounted for by the previous bounce.
                if (its.isEmitter() && incEmission)
                    ret += throughput * its.Le(-ray.dir);

                if (depth >= max_bounces)
                    break;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(its);
                // Emitter sampling is skipped for null BSDFs.
                if (!its.getBSDF()->isNull())
                {
                    Spectrum value = scene.sampleAttenuatedEmitterDirect(
                        dRec, its, sampler->next2D(), sampler, medium);

                    if (!value.isZero())
                    {
                        Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                        Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                        ret += throughput * value * bsdfVal * mis_weight;
                    }
                }

                // ====================== BSDF sampling =============================
                Vector wo;
                Float bsdfPdf, bsdfEta;
                Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
                if (bsdfWeight.isZero())
                    break;

                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                throughput *= bsdfWeight;
                // Switch to the medium on the side of the surface that wo points into.
                if (its.isMediumTransition())
                {
                    med_id = its.getTargetMediumId(wo);
                    medium = its.getTargetMedium(wo);
                }
                if (its.getBSDF()->isNull())
                {
                    // Null BSDF: only find the next intersection; incEmission is left unchanged.
                    scene.rayIntersect(ray, true, its);
                }
                else
                {
                    Spectrum value = scene.rayIntersectAndLookForEmitter(
                        ray, true, sampler, medium, its, dRec);
                    if (!value.isZero())
                    {
                        Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                        ret += throughput * value * mis_weight;
                    }
                    incEmission = false;
                }
            }
            depth++;
        }
        return ret;
    }

    void handleBoundaryInteraction(const Vector &p,
                                   const Scene &scene, Scene &d_scene,
                                   BoundarySamplingRecord &eRec,
                                   RndSampler &sampler, const Spectrum &weight,
                                   int max_depth, std::vector<Spectrum> &d_image)
    {
#if defined(FORWARD)
        int shape_idx = scene.getShapeRequiresGrad();
#endif

        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;
        scene.camera.sampleDirect(p, pixel_uvs, attenuations, dir);
        if (attenuations.isZero())
            return;
        Spectrum value = weight;
        // handle 4 neighbor pixels
        for (int i = 0; i < 4; i++)
        {
            int pixel_idx = scene.camera.getPixelIndex(pixel_uvs.col(i));
            Float d_u = 0.f;
            if (attenuations[i] < Epsilon)
                continue;
            // handle 3 channels
            for (int c = 0; c < 3; c++)
            {
                // chain rule
                d_u += d_image[pixel_idx][c] * (value[c] * attenuations[i]);
            }
#ifdef FORWARD
            d_scene.shape_list[shape_idx]->param = 0;
            Float param = d_scene.shape_list[shape_idx]->param;
#endif
            d_velocity(scene, d_scene, eRec, d_u);
#ifdef FORWARD
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
#endif
        }
    }

    void d_samplePrimaryBoundary(const Scene &scene, Scene &d_scene,
                                 RndSampler &sampler, const RenderOptions &options,
                                 const DiscreteDistribution &edge_dist,
                                 const std::vector<Vector2i> &edge_indices,
                                 std::vector<Spectrum> &d_image)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(sampler.next1D(),
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector xB = v0 + (v1 - v0) * eRec.t,
                     &xD = scene.camera.cpos;
        Ray ray(xB, (xB - xD).normalized());
        BoundaryEndpointSamplingRecord beRecS;
        if (!scene.trace(ray, &sampler, options.max_bounces, eRec.med_id, beRecS))
            return;
        const Vector &xS = beRecS.p;

        // sanity check
        {
            // make sure the ray is tangent to the surface
            if (edge.f0 >= 0 && edge.f1 >= 0)
            {
                Vector n0 = shape->getGeoNormal(edge.f0),
                       n1 = shape->getGeoNormal(edge.f1);
                Float dotn0 = ray.dir.dot(n0),
                      dotn1 = ray.dir.dot(n1);
                if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                    return;
            }
            // NOTE prevent intersection with a backface
            if (beRecS.onSurface)
            {
                Float gnDotS = beRecS.n.dot(-ray.dir);
                Float snDotS = beRecS.sh_n.dot(-ray.dir);
                shape = scene.getShape(beRecS.shape_id);
                const BSDF *bsdf = scene.getBSDF(shape->bsdf_id);
                bool success = (bsdf->isTransmissive() &&
                                math::signum(gnDotS) * math::signum(snDotS) > 0.5f) ||
                               (!bsdf->isTransmissive() && gnDotS > Epsilon && snDotS > Epsilon);
                if (!success)
                    return;
            }
        }

        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = -ray.dir; // NOTE
        eRec.onSurface_S = beRecS.onSurface;
        if (beRecS.onSurface)
        {
            eRec.shape_id_S = beRecS.shape_id;
            eRec.tri_id_S = beRecS.tri_id;
            eRec.barycentric3_S = beRecS.barycentric3;
        }
        else
        {
            eRec.med_id_S = beRecS.med_id;
            eRec.tet_id_S = beRecS.tet_id;
            eRec.barycentric4_S = beRecS.barycentric4;
        }

        /* Jacobian determinant that accounts for the change of variable */
        Vector n = (v0 - v1).cross(-ray.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side
        Float xD_xB = (xB - xD).norm();
        Float xD_xS = (xS - xD).norm();
        Float baseValue = 0; // J * G
        if (beRecS.onSurface)
        {
            Float sB = sinB(xD, xB, (v0 - v1).normalized(), xS);
            Float sD = sinD(xD, xB, (v0 - v1).normalized(), xS, beRecS.n);

            if (sB > Epsilon && sD > Epsilon)
            {
                baseValue = xD_xS / xD_xB * sB / sD;
            }
            baseValue *= abs(beRecS.n.dot(-ray.dir));
        }
        else
        {
            Float sB = sinB(xD, xB, (v0 - v1).normalized(), xS);
            if (sB > Epsilon)
            {
                baseValue = xD_xS / xD_xB * sB;
            }
        }
        assert(baseValue > 0.0);
        /* Sample source path */
        Spectrum value = radiance(scene, Ray(xB, ray.dir), true, &sampler,
                                  scene.getMedium(eRec.med_id), options.max_bounces);
        value *= beRecS.throughput;
        value *= scene.evalTransmittance(xB, true, scene.camera.cpos, false,
                                         scene.getMedium(eRec.med_id), &sampler);
        if (value.isZero(Epsilon))
            return;
        /* Evaluate primary boundary segment */
        handleBoundaryInteraction(xS, scene, d_scene,
                                  eRec, sampler, baseValue * value / eRec.pdf,
                                  options.max_bounces, d_image);
    }

}

/**
 * Construct the integrator and build its edge-sampling distribution for `scene`.
 *
 * The edge selection is performed by configure(), so that construction and later
 * re-configuration always apply the same rule.
 *
 * @param scene scene whose shapes, BSDFs and camera are used to select edges
 */
VolPrimaryEdgeIntegrator::VolPrimaryEdgeIntegrator(const Scene &scene)
{
    /* generate the edge distribution */
    configure(scene);
}

void VolPrimaryEdgeIntegrator::configure(const Scene &scene)
{
    edge_indices.clear();
    edge_dist.clear();
    /* generate the edge distribution */
    auto &camera = scene.camera;
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &shape = *scene.shape_list[i];
        if (scene.bsdf_list[shape.bsdf_id]->isNull() &&
            shape.light_id < 0)
            continue;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];
            // if an edge is a boundary edge or a sihoulette edge
            if (edge.mode == 0 ||
                shape.isSihoulette(edge, camera.cpos))
            {
                edge_indices.push_back({i, j});
                edge_dist.append(edge.length);
                continue;
            }
        }
    }
    edge_dist.normalize();
}

/**
 * Accumulate the primary-boundary gradient contribution into sceneAD.
 *
 * The image is split into blocks of block_size * block_size pixels; each block draws
 * block_size * block_size * num_samples_primary_edge boundary samples. Blocks are
 * distributed over OpenMP threads, every thread accumulates into its own gradient
 * copy (gm.get(tid)), and the copies are merged at the end.
 *
 * @param sceneAD   primal scene (val), shadow scene (der) and gradient manager (gm)
 * @param options   render options (num_samples_primary_edge, block_size, seed, ...)
 * @param __d_image gradient of the image; it is divided by the total number of
 *                  samples before being used
 * @return the flattened debug image (flattened(debugInfo.getArray())), also when no
 *         samples are requested
 */
ArrayXd VolPrimaryEdgeIntegrator::renderD(SceneAD &sceneAD,
                                          RenderOptions &options, const ArrayXd &__d_image) const
{
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_primary_edge;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    // Nothing to sample: return the freshly initialized debug image in the same
    // flattened form as the regular return path below.
    if (nsamples <= 0)
        return flattened(debugInfo.getArray());

    // Pre-divide the image gradient by the total sample count.
    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());
    Timer _("Primary boundary");

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        // A loop iteration is executed by a single thread, so the id is fixed per block.
        const int tid = omp_get_thread_num();
        for (int j = 0; j < nblock_samples; j++)
        {
            // One independent sampler per sample, indexed by the global sample number.
            RndSampler sampler(options.seed, i * nblock_samples + j);
            // sample a point on the boundary
            d_samplePrimaryBoundary(scene, gm.get(tid),
                                    sampler, options,
                                    edge_dist, edge_indices,
                                    _d_image_spec_list);
        }
        if (verbose)
        {
            // blockProcessed is shared between threads; update it under the critical section.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
        }
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scene
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif
    return flattened(debugInfo.getArray());
}
