#include "volpath2.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include <algorithm1.h>
#include <core/statistics.h>
#include <render/spiral.h>
#include <fmt/core.h>
#include <core/math_func.h>
#include "volpathBase.h"
#include <fmt/color.h>
// Free functions backing the interior volumetric path integrator in this file.
namespace volpath2_meta
{
    // Estimates the radiance arriving along `_ray` by tracing one path through
    // surfaces and participating media.
    //
    // At every scattering vertex two estimators are accumulated into the
    // result, each scaled by miWeight(...): an emitter-sampling term and a
    // phase-function / BSDF-sampling term.
    //
    // Parameters:
    //   scene - scene that is intersected and sampled.
    //   _ray  - primary ray; copied into a local `ray` that is updated per bounce.
    //   rRec  - query record; this function reads pixel_idx, med_id, sampler
    //           and max_bounces from it.
    //   path  - optional output. When non-null it is cleared and then filled
    //           with the camera, every recorded medium/surface vertex, and the
    //           emitter connections (append_nee / append_bsdf). When null,
    //           nothing is recorded.
    //
    // Returns the accumulated radiance estimate.
    //
    // `max_null_interactions`, `verbose`, `Epsilon`, `miWeight` and `geometric`
    // are defined outside this block.
    Spectrum __Li(const Scene &scene, const Ray &_ray, const RadianceQueryRecord &rRec, LightPath *path)
    {
        if (path)
        {
            // start a fresh path for this pixel; the camera is its first vertex
            path->clear(rRec.pixel_idx);
            path->append(scene.camera); // NOTE
        }

        // ret: radiance accumulated so far; throughput: product of the path
        // weights collected up to the current vertex
        Spectrum ret = Spectrum::Zero(),
                 throughput = Spectrum::Ones();

        // FIXME: assume the camera is outside any shapes
        // med_id == -1 means the ray currently travels outside any medium
        int med_id = rRec.med_id;
        const Medium *medium = (med_id != -1 ? scene.getMedium(med_id) : nullptr);

        Float pdfFailure = 1.; // keep track of the pdf of hitting a surface
        Float pdfSuccess = 1.; // keep track of the pdf of hitting a medium

        // position of the previous path vertex (initially the camera position);
        // used by geometric(...) when recording vertex pdfs
        Vector preX = scene.camera.cpos;
        // true until the first scattering event that increments `depth`; while
        // true, emission of a directly hit emitter is added and vertex pdfs are
        // recorded without the geometric(...) factor
        bool incEmission = true;

        Ray ray(_ray);
        Intersection its;
        scene.rayIntersect(ray, false, its);

        MediumSamplingRecord mRec;
        RndSampler *sampler = rRec.sampler;
        const int max_bounces = rRec.max_bounces;
        // depth counts scattering events; null_interations counts passes
        // through surfaces whose BSDF is null (these do not increase depth)
        int depth = 0, null_interations = 0;
        while (depth <= max_bounces && null_interations < max_null_interactions)
        {
            // sampleDistance is only called when a medium is present; it fills
            // mRec, which the surface branch below also reads
            bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
            if (inside_med)
            {
                // sampled a medium interaction

                if (depth >= max_bounces)
                    break;

                if (path)
                {
                    // record the medium vertex together with its sampling pdf;
                    // after the first scattering event the pdf is additionally
                    // multiplied by geometric(preX, mRec.p)
                    // NOTE(review): presumably a solid-angle to volume measure
                    // conversion - confirm against geometric()
                    if (incEmission)
                        path->append({mRec, med_id, pdfSuccess * mRec.pdfSuccess}); // NOTE
                    else
                        path->append({mRec, med_id, pdfSuccess * mRec.pdfSuccess * geometric(preX, mRec.p)}); // NOTE
                }

                const PhaseFunction *phase = scene.phase_list[medium->phase_id];
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(mRec.p);
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, mRec.medium);
                if (!value.isZero(Epsilon))
                {
                    Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                    if (phaseVal > Epsilon)
                    {
                        Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
                        // NOTE(review): dRec.pdf / dRec.G presumably expresses the
                        // emitter pdf in the same measure as phasePdf - confirm
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                        ret += throughput * value * phaseVal * mis_weight;

                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== phase sampling =============================
                Vector wo;
                Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
                Float phasePdf = phase->pdf(-ray.dir, wo);
                // the path ends when the sampled phase value is below Epsilon
                if (phaseVal < Epsilon)
                    break;

                throughput *= phaseVal;
                // a new direction was sampled: both running pdfs restart from
                // the directional pdf
                pdfFailure = phasePdf;
                pdfSuccess = phasePdf;
                ray = Ray(mRec.p, wo);

                // trace the sampled direction; this updates `its` and fills
                // dRec, and `value` is non-zero when it yields an emitter
                // contribution
                value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, mRec.medium, its, dRec);
                if (!value.isZero(Epsilon))
                {
                    Float pdf_emitter = scene.pdfEmitterDirect(dRec);
                    Float mis_weight = miWeight(phasePdf, pdf_emitter / dRec.G);
                    ret += throughput * value * mis_weight;

                    if (path)
                        path->append_bsdf({dRec, phasePdf * dRec.G / mis_weight}); // NOTE
                }

                // update loop variables
                incEmission = false;
                preX = mRec.p;
                depth++;
            }
            else
            {
                // sampled a surface interaction

                if (medium)
                {
                    // no medium interaction was sampled before the surface:
                    // fold the corresponding pdf and transmittance from mRec
                    // into the running quantities
                    pdfFailure *= mRec.pdfFailure;
                    pdfSuccess *= mRec.pdfFailure;
                    throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                // the ray hit nothing: the path ends here
                if (!its.isValid())
                    break;

                // emission is added here only while incEmission is still true;
                // later emitter hits are accounted for by the MIS-weighted
                // terms below
                if (its.isEmitter() && incEmission)
                    ret += throughput * its.Le(-ray.dir);

                if (depth >= max_bounces)
                    break;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(its);
                // surfaces with a null BSDF are neither recorded as path
                // vertices nor used for emitter sampling
                if (!its.getBSDF()->isNull())
                {
                    if (path)
                    {
                        if (incEmission)
                            path->append({its, pdfFailure}); // NOTE
                        else
                            path->append({its, pdfFailure * geometric(preX, its.p, its.geoFrame.n)}); // NOTE
                    }

                    Spectrum value = scene.sampleAttenuatedEmitterDirect(
                        dRec, its, sampler->next2D(), sampler, medium);

                    if (!value.isZero(Epsilon))
                    {
                        Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                        Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                        ret += throughput * value * bsdfVal * mis_weight;
                        if (path)
                            path->append_nee({dRec, dRec.pdf / mis_weight}); // NOTE
                    }
                }

                // ====================== BSDF sampling =============================
                Vector wo;
                Float bsdfPdf, bsdfEta;
                Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
                if (bsdfWeight.isZero(Epsilon))
                    break;

                // sampleBSDF produced wo in the local frame; convert to world space
                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                throughput *= bsdfWeight;
                // crossing a medium boundary: switch to the medium on the side
                // the new direction points into
                if (its.isMediumTransition())
                {
                    med_id = its.getTargetMediumId(wo);
                    medium = its.getTargetMedium(wo);
                }
                if (its.getBSDF()->isNull())
                {
                    // null surface: continue through it; depth, preX,
                    // incEmission and the running pdfs are left untouched
                    scene.rayIntersect(ray, true, its);
                    null_interations++;
                }
                else
                {
                    pdfFailure = bsdfPdf;
                    pdfSuccess = bsdfPdf;
                    Spectrum value = scene.rayIntersectAndLookForEmitter(
                        ray, true, sampler, medium, its, dRec);
                    if (!value.isZero(Epsilon))
                    {
                        Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                        ret += throughput * value * mis_weight;
                        if (path)
                            path->append_bsdf({dRec, bsdfPdf * dRec.G / mis_weight}); // NOTE
                    }
                    incEmission = false;
                    // ray.org is the surface point the new ray starts from
                    preX = ray.org;
                    depth++;
                }
            }
        }
        // the loop stopped because the null-interaction budget ran out
        if (null_interations == max_null_interactions)
        {
            if (verbose)
                fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
            // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        }
        return ret;
    }

    // Writes into `res` the dot product of (x - detach(x)) with detach(n),
    // where x and n are taken from scene.getPoint(its); the third returned
    // value J is not used.
    // NOTE(review): detach() presumably yields a copy that is excluded from
    // differentiation, which would make the value zero while keeping the
    // derivative with respect to x along the fixed normal - confirm.
    void velocity(const Scene &scene, const Intersection &its, Float &res)
    {
        auto [x, n, J] = scene.getPoint(its);
        res = (x - detach(x)).dot(detach(n));
    }

    // Derivative counterpart of velocity(). When ENZYME is defined it invokes
    // __enzyme_autodiff on velocity(), passing the scene value together with
    // the derivative scene from sceneAD.getDer(), `its` as a constant, and the
    // output `u` together with `d_u`. Without ENZYME only the derivative scene
    // reference is obtained and nothing else happens.
    // NOTE(review): `d_u` is presumably the seed for the output derivative and
    // the result is accumulated into the derivative scene - confirm against
    // the Enzyme calling convention.
    void d_velocity(SceneAD &sceneAD, const Intersection &its, Float d_u)
    {
        auto &d_scene = sceneAD.getDer();
        // storage for the primal output of velocity(); its value is not read here
        [[maybe_unused]] Float u;
#if defined(ENZYME)
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &sceneAD.val, &d_scene,
                          enzyme_const, &its,
                          enzyme_dup, &u, &d_u);
#endif
    }
} // namespace volpath2_meta

// Primal radiance estimate for the interior integrator.
// Forwards to volpath2_meta::__Li with no LightPath, so no path vertices are
// recorded.
Spectrum VolpathInterior::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    LightPath *const noPathRecording = nullptr;
    return volpath2_meta::__Li(scene, ray, rRec, noPathRecording);
}

// Derivative pass of the interior integrator.
// Traces one path with volpath2_meta::__Li while recording it into a
// LightPath; if the traced radiance is not zero (within Epsilon) the recorded
// path is handed to algorithm1_vol::d_eval together with d_res. Paths with
// zero radiance are skipped.
__attribute__((optnone)) void VolpathInterior::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    LightPath recordedPath;
    Spectrum radiance = volpath2_meta::__Li(sceneAD.val, ray, rRec, &recordedPath);

    if (!radiance.isZero(Epsilon))
    {
        LightPathAD pathAD(recordedPath);
        algorithm1_vol::d_eval(sceneAD.val, sceneAD.getDer(), pathAD, d_res, rRec.sampler);
    }
}

// ================================================================================
//                              BoundaryUnidirectional
// ================================================================================
// Primal radiance of the boundary integrator: always zero.
//
// This class only provides a derivative pass (LiAD); it has no primal
// contribution. The function used to start with assert(false), but
// VolpathMerged::Li in this file calls it unconditionally, so every
// assert-enabled build aborted there. Returning zero without the assert
// matches what NDEBUG builds already did.
//
// Parameters are unused; they exist to match the integrator interface.
Spectrum VolpathBoundary::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    return Spectrum::Zero();
}

// Derivative pass of the boundary integrator for one camera ray.
// A RadianceTracer3 built from sceneAD and d_res first handles the boundary
// term of the primary ray segment (handleBoundary with unit throughput) and
// then continues along the path via its Li().
void VolpathBoundary::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    const Scene &scene = sceneAD.val;
    RndSampler *sampler = rRec.sampler;

    RadianceTracer3 tracer{sceneAD, d_res};

    // boundary contribution of the first segment
    tracer.handleBoundary(scene, sampler, nullptr, ray, rRec.max_bounces, Spectrum::Ones());

    // boundary contributions of the subsequent segments
    tracer.Li(scene, sampler, nullptr, ray, rRec.max_bounces, rRec.incEmission);
}

// =============================================================================
//                              VolpathMerged
// =============================================================================

// Primal radiance of the merged integrator: the sum of the interior and the
// boundary integrators' Li for the same ray and query record.
Spectrum VolpathMerged::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    const Spectrum interiorTerm = volpathInterior.Li(scene, ray, rRec);
    const Spectrum boundaryTerm = volpathBoundary.Li(scene, ray, rRec);
    return interiorTerm + boundaryTerm;
}

// Derivative pass of the merged integrator.
// Sets the scene property "is_get_point" to false, then runs the interior
// derivative pass followed by the boundary derivative pass on the same ray.
void VolpathMerged::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    auto &sceneProps = sceneAD.val.props;
    sceneProps.set("is_get_point", false);

    // interior term first, boundary term second
    volpathInterior.LiAD(sceneAD, ray, rRec, d_res);
    volpathBoundary.LiAD(sceneAD, ray, rRec, d_res);
}

// Primal render: element-wise sum of the interior and the boundary
// integrators' renderC results.
ArrayXd Volpath2::renderC(const Scene &scene, const RenderOptions &options) const
{
    const ArrayXd interiorImage = volpathInterior.renderC(scene, options);
    const ArrayXd boundaryImage = volpathBoundary.renderC(scene, options);
    return interiorImage + boundaryImage;
}

// Derivative render: sum of the interior and the boundary integrators'
// renderD results.
//
// Parameters:
//   sceneAD - scene value plus derivative storage; mutated by both passes.
//   options - render options forwarded to both passes.
//   d_image - incoming image derivative forwarded to both passes.
//
// Both passes write into the same sceneAD. The two calls used to be operands
// of a single `+`, whose evaluation order is unspecified in C++, so which
// pass touched sceneAD first depended on the compiler. They are now sequenced
// explicitly: interior first, then boundary.
ArrayXd Volpath2::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const
{
    sceneAD.val.props.set("is_get_point", false);

    const ArrayXd interiorGrad = volpathInterior.renderD(sceneAD, options, d_image);
    const ArrayXd boundaryGrad = volpathBoundary.renderD(sceneAD, options, d_image);
    return interiorGrad + boundaryGrad;
}

// =============================================================================
//                              BoundaryBidirectional
// =============================================================================
// Builds the per-shape sampling distribution for `scene` once, at construction.
BoundaryBidirectional::BoundaryBidirectional(const Scene &scene)
{
    this->shapeDistribution = this->buildShapeDistribution(scene);
}

// Builds a discrete distribution with one entry per shape in
// scene.shape_list, in list order. Shapes reporting isMediumTransition() get
// their area as weight, all others get weight 0. The distribution is
// normalized before it is returned.
DiscreteDistribution BoundaryBidirectional::buildShapeDistribution(const Scene &scene) const
{
    DiscreteDistribution areaDist;
    for (const auto &shape : scene.shape_list)
    {
        // FIXME: check if the shape contains a medium, not sure this is a valid check in a complex scene
        if (!shape->isMediumTransition())
        {
            // keep an entry so that indices match scene.shape_list
            areaDist.append(0);
            continue;
        }
        areaDist.append(shape->getArea());
    }
    areaDist.normalize();
    return areaDist;
}

// Samples a point on a shape chosen from shapeDistribution and packages it as
// an Intersection.
//
// Parameters:
//   scene - scene providing shapes, BSDFs and media.
//   _rnd  - two random numbers; a local copy is used because the first
//           component is passed to sampleReuse() before the pair is handed to
//           Shape::samplePosition().
//           NOTE(review): sampleReuse presumably rescales rnd[0] so it can be
//           reused - confirm.
//
// Returns the intersection record and its pdf, computed as the shape
// selection probability divided by the shape's area.
//
// Medium ids equal to -1 (the "no medium" value used by __Li in this file)
// yield a null medium pointer instead of being passed to scene.getMedium().
std::pair<Intersection, Float> BoundaryBidirectional::sampleBoundaryPoint(
    const Scene &scene, const Array2 &_rnd) const
{
    Intersection its;
    Array2 rnd(_rnd);
    Float pdf;
    int shape_id = shapeDistribution.sampleReuse(rnd[0], pdf);
    PositionSamplingRecord pRec;
    const Shape *shape = scene.getShape(shape_id);
    int tri_id = shape->samplePosition(rnd, pRec);
    its.indices[0] = shape_id;
    its.indices[1] = tri_id;
    its.barycentric = pRec.uv;
    pdf /= shape->getArea();
    its.ptr_shape = shape;
    its.ptr_bsdf = scene.getBSDF(shape->bsdf_id);
    // a shape may have no medium on one side; do not look up id -1
    its.ptr_med_ext = (shape->med_ext_id != -1 ? scene.getMedium(shape->med_ext_id) : nullptr);
    its.ptr_med_int = (shape->med_int_id != -1 ? scene.getMedium(shape->med_int_id) : nullptr);
    its.p = pRec.p;
    // geometric and shading frames are both built from the sampled normal
    its.geoFrame = its.shFrame = Frame(pRec.n);
    its.pdf = pdf;
    its.J = 1.0;
    return {its, pdf};
}

// Draws one sample of the bidirectional boundary term.
//
// A point is sampled on a medium boundary, a direction is sampled from the
// interior medium's phase function, a source subpath is traced from just
// inside the boundary, and two ImportanceTracer passes are run from the same
// point (handleMedium and importance), both given the radiances of the source
// subpath.
void BoundaryBidirectional::sampleBoundary(SceneAD &sceneAD, VolumeBoundaryQueryRecord &bRec) const
{
    const Scene &scene = sceneAD.val;
    RndSampler *sampler = bRec.sampler;
    int max_bounces = bRec.max_bounces;

    // Step 1: pick a point on the volume boundary
    auto [boundary_its, boundary_pdf] = sampleBoundaryPoint(scene, sampler->next2D());
    const Medium *interior_medium = boundary_its.ptr_med_int;
    assert(interior_medium);

    // Sample the outgoing direction from the phase function.
    // FIXME : wi (the incident direction is a fixed placeholder)
    const PhaseFunction *phase = scene.getPhase(interior_medium->phase_id);
    Vector wi(0., 1., 0.);
    Vector wo;
    phase->sample(wi, sampler->next2D(), wo);

    // Step 2: trace the source subpath, starting slightly inside the boundary
    // (offset against the geometric normal by ShadowEpsilon)
    Vector dir_in = -boundary_its.geoFrame.n;
    Vector p_offset = boundary_its.p + ShadowEpsilon * dir_in;
    RadianceTracer source_tracer;

    auto [lins, radiances] = source_tracer.sampleSource(scene, sampler, interior_medium, p_offset, wo, max_bounces);

    // Remaining steps of the estimator, carried out by the tracers below:
    //   3. sample the detector subpath
    //   4. merge the two subpaths
    //   5. compute the boundary normal velocity
    //   6. compute the boundary term
    //   7. accumulate the boundary term to the gradient image

    // without nee
    // {
    //     ImportanceTracer2 pt{sceneAD, its_b, max_bounces, bRec.d_image, lins, bRec.grad_image};
    //     pt.handleMedium(scene, sampler, medium, p_offset, 1, Spectrum::Ones() * medium->sigS(p_offset) / pdf_b);

    //     ImportanceTracer2 pt2{sceneAD, its_b, max_bounces, bRec.d_image, lins * medium->sigS(p_offset) / pdf_b / INV_FOURPI, bRec.grad_image};
    //     pt2.importance(scene, sampler, medium, Ray{p_offset, wo}, false, 1, max_bounces, Spectrum::Ones());
    // }

    // with nee
    {
        // direct connection to the camera
        ImportanceTracer camera_tracer{sceneAD, boundary_its, max_bounces, bRec.d_image, radiances, bRec.grad_image};
        camera_tracer.handleMedium(scene, sampler, interior_medium, p_offset, 1,
                                   Spectrum::Ones() * interior_medium->sigS(p_offset) / boundary_pdf);

        // importance traced along the sampled direction
        ImportanceTracer importance_tracer{sceneAD, boundary_its, max_bounces, bRec.d_image, radiances, bRec.grad_image};
        const Ray boundary_ray{p_offset, wo};
        importance_tracer.importance(scene, sampler, interior_medium, boundary_ray, false, 1, max_bounces,
                                     Spectrum::Ones() * interior_medium->sigS(p_offset) / boundary_pdf / INV_FOURPI);
    }
}

// Derivative render of the bidirectional boundary term.
//
// The image is processed block by block (Spiral, block size 8) in an OpenMP
// parallel loop. For every pixel of a block, options.num_samples boundary
// samples are drawn with a sampler seeded from options.seed and the flattened
// pixel index. Each thread writes to its own gradient image obtained from the
// ThreadManager; after the loop the GradientManager and ThreadManager are
// merged.
//
// The incoming image derivative is divided by the sample count and by the
// number of pixels of the crop window before use.
//
// Returns the flattened gradient image.
ArrayXd BoundaryBidirectional::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &_d_image) const
{
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;

    // reset the per-thread gradient storage
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero();

    const int num_workers = omp_get_num_procs();
    const auto &camera = scene.camera;
    Spiral spiral(camera.getCropSize(), camera.getOffset(), 8 /* block_size */);

    ImageBlock d_image(camera.getOffset(),
                       camera.getCropSize(),
                       _d_image / options.num_samples / camera.getCropSize().prod());
    ImageBlock grad_image(camera.getOffset(), camera.getCropSize());
    ThreadManager thread_manager(grad_image, num_workers);
    Timer timer("Boundary Bidirectional");

#pragma omp parallel for num_threads(num_workers) schedule(dynamic, 1)
    for (int block_iter = 0; block_iter < spiral.block_count(); block_iter++)
    {
        auto [block_offset, block_extent, block_index] = spiral.next_block();
        ImageBlock block(block_offset, block_extent);
        const int thread_id = omp_get_thread_num();

        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            // one sampler per pixel, shared by all samples of that pixel
            int flat_pixel_idx = ravel_multi_index(pixelIdx, camera.getCropSize());
            RndSampler sampler(options.seed, flat_pixel_idx);

            for (int sample_idx = 0; sample_idx < options.num_samples; sample_idx++)
            {
                VolumeBoundaryQueryRecord bRec{&sampler, options.max_bounces, d_image, thread_manager.get(thread_id)};
                sampleBoundary(sceneAD, bRec);
            }
        }

        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(spiral.block_counter()) / spiral.block_count());
        }
    }
    if (verbose)
        std::cout << std::endl;

    // combine the per-thread scene gradients and gradient images
    gm.merge();
    thread_manager.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif
    return grad_image.flattened();
}

// Primal render of the bidirectional boundary term: a zero array; `scene` and
// `options` are not used.
// NOTE(review): the array is requested as Zero(1, 3) regardless of image
// size - confirm this shape is what callers that add it to a full image expect.
ArrayXd BoundaryBidirectional::renderC(const Scene &scene, const RenderOptions &options) const
{
    ArrayXd zeroImage = ArrayXd::Zero(1, 3);
    return zeroImage;
}