#include "vol_boundary.h"

#include <algorithm>
#include <cassert>
#include <cmath>

#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/logger.h>
namespace
{
    // Debug image shared by this translation unit: re-created at the start of
    // renderD() and, when FORWARD is defined, accumulated into through
    // image_per_thread[tid] by the camera-connection handlers. Entities in an
    // unnamed namespace already have internal linkage, so `static` is not
    // needed here.
    DebugInfo debugInfo;
    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        // getPoint xS
        Vector xS = Vector::Zero();
        if (bRec.onSurface_S)
        {
            assert(bRec.shape_id_S != -1);
            assert(bRec.tri_id_S != -1);
            assert(abs(1.0 - bRec.barycentric3_S.sum()) < 1e-5);
            const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
            const auto &indS = shapeS->getIndices(bRec.tri_id_S);
            // getPoint
            for (int i = 0; i < 3; i++)
                xS += bRec.barycentric3_S[i] * shapeS->getVertex(indS[i]);
        }
        else
        {
            assert(bRec.med_id_S != -1);
            assert(bRec.tet_id_S != -1);
            assert(abs(1.0 - bRec.barycentric4_S.sum()) < 1e-5);
            const Medium *medium = scene.medium_list[bRec.med_id_S];
            const Vector4i &indices = medium->getTet(bRec.tet_id_S);

            // getPoint
            for (int i = 0; i < 4; i++)
                xS += bRec.barycentric4_S[i] *
                      medium->getVertex(scene, indices[i]);
        }

        if (bRec.onSurface_D)
        {
            const Shape *shapeD = scene.shape_list[bRec.shape_id_D];
            const auto &indD = shapeD->getIndices(bRec.tri_id_D);
            const Vector &xD_0 = shapeD->getVertex(indD[0]);
            const Vector &xD_1 = shapeD->getVertex(indD[1]);
            const Vector &xD_2 = shapeD->getVertex(indD[2]);
            res = normal_velocity(xS,
                                  xB_0, xB_1, xB_2, bRec.t,
                                  xD_0, xD_1, xD_2);
        }
        else
        {
            assert(bRec.med_id_D != -1);
            assert(bRec.tet_id_D != -1);
            assert(abs(1.0 - bRec.barycentric4_D.sum()) < 1e-5);
            const Medium *medium = scene.medium_list[bRec.med_id_D];
            const Vector4i &indices = medium->getTet(bRec.tet_id_D);
            const Vector &xD_0 = medium->getVertex(scene, indices[0]);
            const Vector &xD_1 = medium->getVertex(scene, indices[1]);
            const Vector &xD_2 = medium->getVertex(scene, indices[2]);
            const Vector &xD_3 = medium->getVertex(scene, indices[3]);
            // getPoint
            Vector xD = Vector::Zero();
            for (int i = 0; i < 4; i++)
                xD += bRec.barycentric4_D[i] *
                      medium->getVertex(scene, indices[i]);
            Float dist = (xD - xS).norm();
            res = normal_velocity(xS,
                                  xB_0, xB_1, xB_2, bRec.t,
                                  xD_0, xD_1, xD_2, xD_3, detach(dist));
        }
    }

    /*
     * Back-propagate the adjoint d_u of velocity() into d_scene.
     *
     * Only active when both ENZYME and ENZYME_BOUNDARY_DIRECT are defined:
     * Enzyme differentiates velocity() with `scene` shadowed by `d_scene`,
     * `eRec` treated as constant, and the primal output `u` marked
     * enzyme_dupnoneed (its value is not needed; its adjoint is d_u).
     * Otherwise this function is a no-op, hence the [[maybe_unused]]
     * parameters.
     */
    void d_velocity([[maybe_unused]] const Scene &scene,
                    [[maybe_unused]] Scene &d_scene,
                    [[maybe_unused]] const BoundarySamplingRecord &eRec,
                    [[maybe_unused]] Float d_u)
    {
        // Primal output slot for velocity(); initialised so it never holds an
        // indeterminate value.
        [[maybe_unused]] Float u = 0;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_DIRECT)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    void handleSurfaceInteraction(const Intersection &its,
                                  const Scene &scene, Scene &d_scene,
                                  BoundarySamplingRecord &eRec,
                                  RndSampler &sampler, const Spectrum &weight,
                                  int max_depth, std::vector<Spectrum> &d_image)
    {
#if defined(FORWARD)
        int shape_idx = scene.getShapeRequiresGrad();
#endif

        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;
        scene.camera.sampleDirect(its.p, pixel_uvs, attenuations, dir);
        // =====================  Volpath =====================
        Float transmittance = scene.evalTransmittance(
            its.p, true, scene.camera.cpos, false,
            its.getTargetMedium(dir), &sampler);
        if (transmittance < Epsilon)
            return;
        attenuations *= transmittance;
        // ====================================================
        auto bsdf_val = its.evalBSDF(its.toLocal(dir),
                                     EBSDFMode::EImportanceWithCorrection);
        Spectrum value = weight * bsdf_val;
        // handle 4 neighbor pixels
        for (int i = 0; i < 4; i++)
        {
            int pixel_idx = scene.camera.getPixelIndex(pixel_uvs.col(i));
            Float d_u = 0.f;
            if (eRec.pdf < Epsilon ||
                attenuations[i] < Epsilon)
                continue;
            // handle 3 channels
            for (int c = 0; c < 3; c++)
            {
                // chain rule
                d_u += d_image[pixel_idx][c] *
                       (value[c] * attenuations[i]);
            }
#ifdef FORWARD
            Float param = d_scene.shape_list[shape_idx]->param;
#endif
            d_velocity(scene, d_scene, eRec, d_u);
#ifdef FORWARD
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
#endif
        }
    }

    void handleMediumInteraction(const Vector &p, const Vector &wi, const Medium *medium,
                                 const Scene &scene, Scene &d_scene,
                                 BoundarySamplingRecord &eRec,
                                 RndSampler &sampler, const Spectrum &weight,
                                 int max_depth, std::vector<Spectrum> &d_image)
    {
#if defined(FORWARD)
        int shape_idx = scene.getShapeRequiresGrad();
#endif

        Matrix2x4 pixel_uvs;
        Array4 attenuations(0.0);
        Vector dir;

        scene.camera.sampleDirect(p, pixel_uvs, attenuations, dir);
        // =====================  Volpath =====================
        Float transmittance = scene.evalTransmittance(
            p, false, scene.camera.cpos, false, medium, &sampler);
        if (transmittance < Epsilon)
            return;
        attenuations *= transmittance;
        // ====================================================
        const PhaseFunction *phase = scene.getPhase(medium->phase_id);
        auto phase_val = phase->eval(wi, dir);
        Spectrum value = weight * phase_val;

        for (int i = 0; i < 4; i++)
        {
            int pixel_idx = scene.camera.getPixelIndex(pixel_uvs.col(i));
            Float d_u = 0.f;
            if (eRec.pdf < Epsilon ||
                attenuations[i] < Epsilon)
                continue;
            for (int c = 0; c < 3; c++)
            {
                // chain rule
                d_u += d_image[pixel_idx][c] *
                       (value[c] * attenuations[i]);
            }
#ifdef FORWARD
            Float param = d_scene.shape_list[shape_idx]->param;
#endif
            d_velocity(scene, d_scene, eRec, d_u);
#ifdef FORWARD
            param = d_scene.shape_list[shape_idx]->param - param;
            const int tid = omp_get_thread_num();
            debugInfo.image_per_thread[tid][pixel_idx] += Spectrum(param, 0, 0);
#endif
        }
    }

    void d_sampleDirectBoundary(const Scene &scene, Scene &d_scene,
                                RndSampler &sampler, const RenderOptions &options,
                                const DiscreteDistribution &edge_dist,
                                const std::vector<Vector2i> &edge_indices,
                                std::vector<Spectrum> &d_image)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(sampler.next1D(),
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id < 0)
        {
            PSDR_WARN(eRec.shape_id < 0);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        assert(edge.f0 >= 0);

        /* Sample point on emitters */
        DirectSamplingRecord dRec(eRec.ref);
        Vector2 rnd_light = sampler.next2D();
        const Medium *medium = scene.getMedium(eRec.med_id);
        Spectrum value = scene.sampleBoundaryAttenuatedEmitterDirect(
            dRec, rnd_light, &sampler, medium);
        if (value.isZero(Epsilon))
            return;
        const Vector xB = eRec.ref,
                     &xS = dRec.p;
        Ray ray(xB, (xB - xS).normalized());

        BoundaryEndpointSamplingRecord beRecD;
        if (!scene.trace(ray, &sampler, options.max_bounces, eRec.med_id, beRecD))
            return;

        // sanity check
        // make sure the ray is tangent to the surface
        if (edge.f0 >= 0 && edge.f1 >= 0)
        {
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Float dotn0 = ray.dir.dot(n0),
                  dotn1 = ray.dir.dot(n1);
            if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                return;
        }
        // NOTE prevent intersection with a backface
        if (beRecD.onSurface)
        {
            Float gnDotD = beRecD.n.dot(-ray.dir);
            Float snDotD = beRecD.sh_n.dot(-ray.dir);
            shape = scene.getShape(beRecD.shape_id);
            const BSDF *bsdf = scene.getBSDF(shape->bsdf_id);
            bool success = (bsdf->isTransmissive() &&
                            math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                           (!bsdf->isTransmissive() && gnDotD > 0.01 && snDotD > 0.01);
            if (!success)
                return;
        }
        // populate the data in BoundarySamplingRecord eRec
        eRec.dir = -ray.dir; // NOTE
        eRec.shape_id_S = dRec.shape_id;
        eRec.tri_id_S = dRec.tri_id;
        eRec.barycentric3_S = Vector(1. - dRec.barycentric.sum(),
                                     dRec.barycentric[0],
                                     dRec.barycentric[1]);
        eRec.onSurface_D = beRecD.onSurface;
        if (beRecD.onSurface)
        {
            eRec.shape_id_D = beRecD.shape_id;
            eRec.tri_id_D = beRecD.tri_id;
            eRec.barycentric3_D = beRecD.barycentric3;
        }
        else
        {
            eRec.med_id_D = beRecD.med_id;
            eRec.tet_id_D = beRecD.tet_id;
            eRec.barycentric4_D = beRecD.barycentric4;
        }

        /* Jacobian determinant that accounts for the change of variable */
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector &xD = beRecD.p;
        Vector n = (v0 - v1).cross(ray.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side
        Float baseValue = 0;
        if (beRecD.onSurface)
        {
            Float J = dlD_dlB(xS,
                              xB, (v0 - v1).normalized(),
                              xD, beRecD.n) *
                      dA_dw(xB, xS, dRec.n);
            Float G = geometric(xD, beRecD.n,
                                xS, dRec.n);
            baseValue = J * G;
        }
        else
        {
            Float J = dASdAD_dlBdwBdrD(xS, dRec.n,
                                       xB, (v0 - v1).normalized(),
                                       xD);
            Float G = geometric(xD,
                                xS, dRec.n);
            baseValue = J * G;
        }
        if (std::abs(baseValue) < Epsilon)
            return;
        // assert(baseValue > -Epsilon);

        /* Sample detector path */
        Spectrum throughput(1.0f);
        throughput *= beRecD.throughput; // NOTE
        Ray ray_sensor;
        bool inside_med = !beRecD.onSurface;
        Intersection its;
        MediumSamplingRecord mRec;
        if (beRecD.onSurface)
            its = beRecD.its;
        else
        {
            mRec = beRecD.mRec;
            medium = mRec.medium;
        }
        // max_bounces: include the current bounce
        int max_bounces = options.max_bounces - dRec.interactions - beRecD.interactions;
        for (int depth = 0; depth < max_bounces; depth++)
        {
            if (inside_med)
            {
                handleMediumInteraction(
                    mRec.p, mRec.wi, mRec.medium,
                    scene, d_scene,
                    eRec, sampler, value * throughput * baseValue / eRec.pdf,
                    max_bounces - depth - 1, d_image);

                const PhaseFunction *phase = scene.getPhase(mRec.medium->phase_id);
                Vector wo;
                Float phaseVal = phase->sample(-ray_sensor.dir, sampler.next2D(), wo);
                if (phaseVal == 0)
                    return;
                value *= phaseVal;
                ray_sensor = Ray(mRec.p, wo);
            }
            else
            {
                if (!its.isValid())
                    break;
                if (!its.getBSDF()->isNull())
                {
                    handleSurfaceInteraction(
                        its,
                        scene, d_scene,
                        eRec, sampler, value * throughput * baseValue / eRec.pdf,
                        max_bounces - depth - 1, d_image);
                }
                Vector wo_local, wo;
                Float bsdf_pdf, bsdf_eta;
                Spectrum bsdf_weight = its.sampleBSDF(sampler.next3D(), wo_local,
                                                      bsdf_pdf, bsdf_eta,
                                                      EBSDFMode::EImportanceWithCorrection);
                if (bsdf_weight.isZero())
                    break;
                wo = its.toWorld(wo_local);
                Vector wi = its.toWorld(its.wi);
                Float wiDotGeoN = wi.dot(its.geoFrame.n),
                      woDotGeoN = wo.dot(its.geoFrame.n);
                if (wiDotGeoN * its.wi.z() <= 0 ||
                    woDotGeoN * wo_local.z() <= 0)
                    break;
                value *= bsdf_weight;
                if (its.isMediumTransition())
                {
                    medium = its.getTargetMedium(wo);
                }
                ray_sensor = Ray(its.p, wo);
            }
            scene.rayIntersect(ray_sensor, !inside_med, its);
            inside_med = medium != nullptr &&
                         medium->sampleDistance(ray_sensor, its.t, &sampler, mRec);

            if (inside_med)
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            else
            {
                if (medium)
                    throughput *= mRec.transmittance / mRec.pdfFailure;
            }
        }
    }
} // namespace

// Construct from a Properties bag. No scene lookup is performed here, so the
// edge distribution is not built by this constructor; presumably the owner
// calls configure(scene) before rendering (TODO confirm).
VolDirectEdgeIntegrator::VolDirectEdgeIntegrator([[maybe_unused]] const Properties &props)
{
    // TODO: fetch the "scene" object from `props` and configure from it, e.g.
    // auto scene = dynamic_cast<const Scene *>(props.get<Object*>("scene"));
    // if (!scene)
    //     throw std::runtime_error("VolPrimaryEdgeIntegrator: No scene was specified!");
    // configure(*scene);
}

// Build the integrator directly from a scene. Delegates to configure() so the
// edge-distribution construction lives in exactly one place.
VolDirectEdgeIntegrator::VolDirectEdgeIntegrator(const Scene &scene)
{
    configure(scene);
}

/*
 * (Re)build the edge sampling distribution for `scene`.
 *
 * Shapes whose BSDF is null and that are not emitters (light_id < 0) are
 * skipped. For the remaining shapes, an edge is added to edge_indices as
 * (shape index, edge index) with weight:
 *   - edge.length * 4 * pi             if edge.mode == 0 or the BSDF is
 *                                      transmissive;
 *   - edge.length * 4 * acos(n0 . n1)  if edge.mode > 0, where n0/n1 are the
 *                                      geometric normals of faces f0/f1.
 * Edges with edge.mode < 0 on non-transmissive shapes are not added. The
 * distribution is normalized at the end.
 */
void VolDirectEdgeIntegrator::configure(const Scene &scene)
{
    PSDR_INFO("Configuring VolDirectEdgeIntegrator");
    edge_dist.clear();
    edge_indices.clear();
    /* generate the edge distribution */
    for (size_t i = 0; i < scene.shape_list.size(); i++)
    {
        auto &shape = *scene.shape_list[i];
        const BSDF &bsdf = *scene.bsdf_list[shape.bsdf_id];
        if (bsdf.isNull() &&
            shape.light_id < 0)
            continue;
        for (size_t j = 0; j < shape.edges.size(); j++)
        {
            const Edge &edge = shape.edges[j];
            // if an edge is a boundary edge or a silhouette edge
            if (edge.mode == 0 ||
                bsdf.isTransmissive())
            {
                edge_indices.push_back({i, j});
                edge_dist.append(edge.length * 4 * M_PI); // area of a unit sphere
                continue;
            }
            if (edge.mode > 0)
            {
                const Vector &n0 = shape.getGeoNormal(edge.f0),
                             &n1 = shape.getGeoNormal(edge.f1);
                // Round-off can push the dot product of two unit normals
                // slightly outside [-1, 1], where std::acos returns NaN and
                // would poison the whole distribution.
                Float cos_angle = n0.dot(n1);
                cos_angle = std::clamp(cos_angle, Float(-1), Float(1));
                edge_indices.push_back({i, j});
                edge_dist.append(edge.length * 4 * std::acos(cos_angle));
            }
        }
    }
    edge_dist.normalize();
}

/*
 * Differentiate the direct-boundary term and accumulate it into sceneAD.der.
 *
 * The image is processed in blocks of block_size^2 pixels, each block drawing
 * block_size^2 * num_samples_secondary_edge_direct boundary samples in
 * parallel (one RndSampler per sample). The image adjoint __d_image is
 * pre-divided by the samples per block and the number of blocks. Per-thread
 * gradients are merged into d_scene afterwards.
 *
 * Returns the flattened debug image (non-zero only when FORWARD is defined),
 * also when there is nothing to sample.
 */
ArrayXd VolDirectEdgeIntegrator::renderD(
    SceneAD &sceneAD, RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("Rendering VolDirectEdgeIntegrator with spp = {}", options.num_samples_secondary_edge_direct);
    const Scene &scene = sceneAD.val;
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_secondary_edge_direct;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    if (nsamples <= 0)
        return flattened(debugInfo.getArray()); // same layout as the normal return

    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());

    Timer _("Direct boundary");
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        // One thread runs the whole iteration, so the thread id is invariant.
        const int tid = omp_get_thread_num();
        // The global sample index is about pixels * spp and can exceed
        // INT_MAX at high sample counts, so compute it in 64 bits.
        const long long block_offset = static_cast<long long>(i) * nblock_samples;
        for (int j = 0; j < nblock_samples; j++)
        {
            RndSampler sampler(options.seed, block_offset + j);
            // sample a point on the boundary
            d_sampleDirectBoundary(scene, gm.get(tid),
                                   sampler, options,
                                   edge_dist, edge_indices,
                                   _d_image_spec_list);
        }
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
        }
    }
    if (verbose)
        std::cout << std::endl;

    // merge gradient to d_scene
    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    d_precompute_normal(scene, d_scene);
#endif
    return flattened(debugInfo.getArray());
}
