#include "path.h"
#include "../emitter/area.h"
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"

// #define USE_TEMPLATE
namespace
{
    /**
     * Estimates the radiance gathered at `its` over at most `depth` further
     * bounces. The bounce budget is a template parameter, so the recursion is
     * expressed through `Lins<depth - 1>` and terminates at compile time.
     *
     * Each level combines two estimators, weighted against each other with
     * squared-pdf ratios:
     *   - emitter sampling through Scene::sampleEmitterDirect, and
     *   - BSDF sampling followed by a ray cast to the next vertex.
     *
     * @tparam depth   Remaining bounce budget; `depth <= 0` yields zero.
     * @param its      Current path vertex; an invalid intersection yields zero.
     * @param scene    Scene queried for emitter samples and ray intersections.
     * @param sampler  Random source. One 2D sample followed by one 3D sample is
     *                 drawn per level that reaches a valid vertex.
     * @return The emitter-sampling term at this vertex plus, when the BSDF
     *         sample survives all checks, the weighted emission at the next
     *         vertex and the recursive estimate from it.
     */
    template <int depth>
    inline Spectrum Lins(const Intersection &its, const Scene &scene,
                         RndSampler *sampler)
    {
        if constexpr (depth <= 0)
        {
            return Spectrum::Zero();
        }
        else
        {
            if (!its.isValid())
                return Spectrum::Zero();
            Spectrum ret = Spectrum::Zero();
            Spectrum throughput = Spectrum::Ones();

            // Light sampling
            Vector wo;
            DirectSamplingRecord dRec(its);
            Spectrum value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
            if (!value.isZero(Epsilon))
            {
                // Convert the sampled direction to local coordinates exactly once
                // and reuse it for both the BSDF value and its pdf. (The pdf was
                // previously evaluated on a direction transformed twice, which
                // disagreed with the runtime-depth overload.)
                Vector nee_dir_local = its.toLocal(dRec.dir);
                Spectrum bsdf_val = its.evalBSDF(nee_dir_local);
                // Scaled by dRec.G so it can be compared with dRec.pdf in the weight below.
                Float bsdf_pdf = its.pdfBSDF(nee_dir_local) * dRec.G;
                Float mis_weight = square(dRec.pdf) / (square(dRec.pdf) + square(bsdf_pdf));
                ret += throughput * value * bsdf_val * mis_weight;
            }

            // BSDF sampling
            Vector wo_local;
            Float bsdf_pdf, bsdf_eta;
            // The returned weight is only used as a "sample failed" test; the
            // throughput below re-evaluates the BSDF explicitly.
            Spectrum bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta);
            if (bsdf_weight.isZero())
                return ret;
            wo = its.toWorld(wo_local);
            Vector wi = its.toWorld(its.wi);
            // Reject the sample when either direction's side relative to
            // geoFrame.n disagrees with the sign of its local z component.
            if (wi.dot(its.geoFrame.n) * its.wi.z() < Epsilon ||
                wo.dot(its.geoFrame.n) * wo_local.z() < Epsilon)
                return ret;
            Intersection its1;
            if (!scene.rayIntersect(Ray(its.p, wo), true, its1, IntersectionMode::EMaterial))
                return ret;
            // Direction at the new vertex pointing back along the ray.
            Vector dir = its1.toWorld(its1.wi);
            Float G = geometric(its.p, its1.p, its1.geoFrame.n);
            // The pdf is scaled by a detached copy of G; the throughput keeps the
            // attached G (presumably so only the integrand is differentiated — confirm).
            bsdf_pdf *= detach(G);
            throughput *= its.evalBSDF(its.toLocal(-dir)) * G * its1.J / bsdf_pdf;
            if (throughput.isZero(Epsilon))
                return ret;
            if (its1.isEmitter())
            {
                Spectrum le = its1.Le(dir);
                if (!le.isZero(Epsilon))
                {
                    Float pdf_nee = scene.pdfEmitterSample(its1);
                    Float mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                    ret += throughput * le * mis_weight;
                }
            }
            ret += throughput * Lins<depth - 1>(its1, scene, sampler);
            return ret;
        }
    }

    /**
     * Runtime-depth radiance estimator: gathers light arriving at `its` over at
     * most `max_depth` further bounces by recursing on the next path vertex.
     *
     * Every level adds an emitter-sampling term and, if the BSDF sample
     * survives all checks, the emission seen at the next vertex; the two are
     * balanced with squared-pdf weights.
     *
     * @param its        Current path vertex.
     * @param scene      Scene used for emitter sampling and ray casting.
     * @param sampler    Random source; a 2D sample and then a 3D sample are
     *                   consumed per level.
     * @param max_depth  Remaining bounce budget.
     * @return Zero when `max_depth <= 0` or `its` is invalid, otherwise the
     *         accumulated estimate.
     */
    Spectrum Lins(const Intersection &its, const Scene &scene,
                  RndSampler *sampler, int max_depth)
    {
        // Out of bounces, or nothing was hit: no contribution.
        if (max_depth <= 0 || !its.isValid())
            return Spectrum::Zero();

        Spectrum gathered = Spectrum::Zero();
        Spectrum path_weight = Spectrum::Ones();

        // ---- Emitter sampling ----
        DirectSamplingRecord light_rec(its);
        Spectrum emitter_val = scene.sampleEmitterDirect(sampler->next2D(), light_rec);
        if (!emitter_val.isZero(Epsilon))
        {
            Vector to_light_local = its.toLocal(light_rec.dir);
            Spectrum f_light = its.evalBSDF(to_light_local);
            Float pdf_as_bsdf = its.pdfBSDF(to_light_local) * light_rec.G;
            Float w_light = square(light_rec.pdf) / (square(light_rec.pdf) + square(pdf_as_bsdf));
            gathered += path_weight * emitter_val * f_light * w_light;
        }

        // ---- BSDF sampling ----
        Vector out_local;
        Float pdf_bsdf, eta_bsdf;
        Spectrum sampled_weight = its.sampleBSDF(sampler->next3D(), out_local, pdf_bsdf, eta_bsdf);
        if (sampled_weight.isZero())
            return gathered;

        // Both the incoming and the sampled direction must lie on the side of
        // geoFrame.n indicated by the sign of their local z component.
        Vector out_world = its.toWorld(out_local);
        Vector in_world = its.toWorld(its.wi);
        if (in_world.dot(its.geoFrame.n) * its.wi.z() < Epsilon)
            return gathered;
        if (out_world.dot(its.geoFrame.n) * out_local.z() < Epsilon)
            return gathered;

        Intersection next_its;
        if (!scene.rayIntersect(Ray(its.p, out_world), true, next_its, IntersectionMode::EMaterial))
            return gathered;

        // Direction at the next vertex pointing back along the traced ray.
        Vector back_dir = next_its.toWorld(next_its.wi);
        Float geom = geometric(its.p, next_its.p, next_its.geoFrame.n);
        pdf_bsdf *= detach(geom);
        path_weight *= its.evalBSDF(its.toLocal(-back_dir)) * geom * next_its.J / pdf_bsdf;
        if (path_weight.isZero(Epsilon))
            return gathered;

        if (next_its.isEmitter())
        {
            Spectrum emitted = next_its.Le(back_dir);
            if (!emitted.isZero(Epsilon))
            {
                Float pdf_light = scene.pdfEmitterSample(next_its);
                Float w_bsdf = square(pdf_bsdf) / (square(pdf_light) + square(pdf_bsdf));
                gathered += path_weight * emitted * w_bsdf;
            }
        }

        // Continue the path from the newly found vertex.
        gathered += path_weight * Lins(next_its, scene, sampler, max_depth - 1);
        return gathered;
    }

    /**
     * Evaluates the contribution of one camera ray to a pixel.
     *
     * The ray is intersected with the scene; on a hit, the result is the
     * camera filter weight (times the intersection's `J`) multiplied by the
     * directly visible emission plus the multi-bounce estimate from `Lins`.
     *
     * @param scene      Scene to trace against.
     * @param sampler    Random source forwarded to `Lins`.
     * @param _ray       Camera ray; copied locally, the argument is not modified.
     * @param pixel_idx  Pixel coordinates passed to the camera filter.
     * @param options    Supplies `max_bounces` for the runtime-depth estimator.
     * @return Zero if the ray hits nothing, otherwise the weighted radiance.
     */
    Spectrum radiance(const Scene &scene, RndSampler *sampler, Ray &_ray,
                      const Array2i &pixel_idx, const RenderOptions &options)
    {
        Ray ray(_ray);

        // Perform the first intersection
        Intersection its;
        scene.rayIntersect(ray, true, its, IntersectionMode::EMaterial);
        if (!its.isValid())
            return Spectrum::Zero();

        Spectrum pixel_weight = Spectrum::Ones();
        pixel_weight *= scene.camera.evalFilter(pixel_idx.x(), pixel_idx.y(), its) * its.J;

        // Emission seen directly by the camera.
        Spectrum result = Spectrum::Zero();
        if (its.isEmitter())
            result += pixel_weight * its.Le(-ray.dir);

        // Light gathered over further bounces.
#ifdef USE_TEMPLATE
        result += pixel_weight * Lins<5>(its, scene, sampler);
#else
        result += pixel_weight * Lins(its, scene, sampler, options.max_bounces);
#endif

        return result;
    }

    void Li(const Scene &scene, const Array2i &pixel_idx,
            RndSampler *sampler, const RenderOptions &options, Spectrum &res)
    {
        res = Spectrum::Zero();
        const auto &cam = scene.camera;
        Ray ray_primal, ray_dual;
        // antithetic sampling
        cam.samplePrimaryRayFromFilter(pixel_idx.x(), pixel_idx.y(), sampler->next2D(), ray_primal, ray_dual);
        sampler->save();
        res += radiance(scene, sampler, ray_primal, pixel_idx, options);
        sampler->restore();
        res += radiance(scene, sampler, ray_dual, pixel_idx, options);
        res /= 2;
    }

    void pixelColor(const Scene &scene, const Array2i &pixel_idx,
                    RndSampler *sampler, const RenderOptions &options, Spectrum &res)
    {
        int n_samples = options.num_samples;
        res = Spectrum::Zero();
        for (int i = 0; i < n_samples; i++)
        {
            Spectrum val;
            Li(scene, pixel_idx, sampler, options, val);
            res += val;
        }
        res /= n_samples;
    }

    /**
     * Differentiates `Li` with Enzyme once per sample for a single pixel.
     *
     * @param scene      Scene passed to Enzyme as the first half of the
     *                   `enzyme_dup` pair.
     * @param d_scene    Second half of the `enzyme_dup` pair; presumably the
     *                   shadow scene that accumulates the gradients — confirm
     *                   against the Enzyme calling convention.
     * @param pixel_idx  Pixel coordinates (marked `enzyme_const`).
     * @param sampler    Random source (marked `enzyme_const`).
     * @param options    Render options (marked `enzyme_const`); `num_samples`
     *                   sets both the loop count and the per-sample scale.
     * @param d_res      Incoming derivative for this pixel's value, taken by value.
     */
    void d_pixelColor(const Scene &scene, Scene &d_scene,
                      const Array2i &pixel_idx, RndSampler *sampler,
                      const RenderOptions &options, Spectrum d_res)
    {
        // Storage for Li's output argument; Li overwrites it on entry.
        Spectrum res;
        for (int i = 0; i < options.num_samples; i++)
        {
            // Per-sample seed: d_res scaled by 1/num_samples, matching the mean
            // taken in pixelColor. It is rebuilt every iteration rather than
            // hoisted; NOTE(review): presumably because Enzyme may modify the
            // shadow of an output argument during the call — confirm.
            [[maybe_unused]] Spectrum _d_res = d_res / options.num_samples;
            // Argument layout mirrors Li's parameter list: (scene, pixel_idx,
            // sampler, options, res), each preceded by its Enzyme activity tag.
            __enzyme_autodiff((void *)Li,
                              enzyme_dup, &scene, &d_scene,
                              enzyme_const, &pixel_idx,
                              enzyme_const, sampler,
                              enzyme_const, &options,
                              enzyme_dupnoneed, &res, &_d_res);
        }
    }
} // namespace

void Path::render(const Scene &scene, const RenderOptions &options,
                  ptr<float> rendered_image) const
{
    Timer _("Forward rendering");
    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    int blockProcessed = 0;
    std::vector<Spectrum> spec_list(camera.getNumPixels(), Spectrum::Zero());
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            Spectrum pixel_val = Spectrum::Zero();
            pixelColor(scene, pixelIdx, &sampler, options, pixel_val);
            spec_list[pixel_ravel_idx] += pixel_val;
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    // from_spectrum_list_to_ptr(spec_list, camera.getNumPixels(), rendered_image);
    std::cout << std::endl;
}

/**
 * Forward rendering: renders `scene` and returns the image as a tensor.
 *
 * The image is processed in 16x16 blocks distributed over OpenMP threads;
 * each pixel uses a sampler seeded from `options.seed` and the pixel's
 * flattened index. Progress is reported only when `verbose` is set.
 *
 * @param scene    Scene to render.
 * @param options  Render options (seed, sample count, bounce limit).
 * @return The per-pixel results converted by `from_spectrum_list_to_tensor`.
 */
ArrayXd Path::renderC(const Scene &scene, const RenderOptions &options) const
{
    Timer _("Forward rendering");
    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;

    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    std::vector<Spectrum> pixel_buffer(camera.getNumPixels(), Spectrum::Zero());
    int blocks_done = 0;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int block_id = 0; block_id < blocks.m_BlocksTotal; ++block_id)
    {
        ImageBlock block = blocks.getBlock(block_id);
        Array2i px = block.curPixel();
        while (block.hasNext())
        {
            int flat_idx = ravel_multi_index(px, {camera.width, camera.height});
            // One independent sampler per pixel.
            RndSampler sampler(options.seed, flat_idx);
            Spectrum color = Spectrum::Zero();
            pixelColor(scene, px, &sampler, options, color);
            pixel_buffer[flat_idx] += color;
            px = block.nextPixel();
        }
        if (verbose)
        {
            // The shared block counter is only touched inside the critical section.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blocks_done) / blocks.size());
        }
    }

    if (verbose)
    {
        std::cout << std::endl;
    }
    return from_spectrum_list_to_tensor(pixel_buffer, camera.getNumPixels());
}

/**
 * Differentiates the interior term of the rendering for every pixel.
 *
 * For each pixel, `d_pixelColor` is invoked with the matching entry of
 * `__d_image` as the incoming derivative and the calling thread's slot of the
 * gradient manager as the shadow scene. After the parallel loop the per-thread
 * slots are merged with `gm.merge()`.
 *
 * When built with FORWARD, the change of the tracked shape's `param` caused by
 * each pixel is additionally recorded in the first channel of the returned
 * image; otherwise the returned image stays zero.
 *
 * @param sceneAD    Bundle of the scene (`val`), its derivative scene (`der`)
 *                   and the per-thread gradient manager (`gm`).
 * @param options    Render options (seed, sample count, bounce limit).
 * @param __d_image  Derivative with respect to the rendered image.
 * @return The per-pixel gradient image as a tensor.
 */
ArrayXd Path::renderD(SceneAD &sceneAD,
                      const RenderOptions &options, const ArrayXd &__d_image) const
{
    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    Scene &d_scene = sceneAD.der;
    gm.setZero(); // zero multi-thread gradient
    const int nworker = omp_get_num_procs();

    // const int nworker = 1;
    Timer _("Render interior");

    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    // incoming image derivative, one Spectrum per pixel
    const std::vector<Spectrum> d_image_spec = from_tensor_to_spectrum_list(__d_image, camera.getNumPixels());
    // gradient image
    std::vector<Spectrum> grad_image(camera.getNumPixels(), Spectrum::Zero());
#ifdef FORWARD
    int shape_idx = scene.getShapeRequiresGrad();
#endif

    int blocks_done = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int block_id = 0; block_id < blocks.m_BlocksTotal; ++block_id)
    {
        const int thread_id = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(block_id);
        Array2i px = block.curPixel();
        while (block.hasNext())
        {
#ifdef FORWARD
            // Reset the tracked parameter so its change can be attributed to this pixel.
            gm.get(thread_id).shape_list[shape_idx]->param = 0;
            Float param_before = gm.get(thread_id).shape_list[shape_idx]->param;
#endif
            int flat_idx = ravel_multi_index(px, {camera.width, camera.height});
            RndSampler sampler(options.seed, flat_idx);
            d_pixelColor(scene, gm.get(thread_id),
                         px, &sampler, options, d_image_spec[flat_idx]);
#ifdef FORWARD
            // Non-finite values are left out of the gradient image.
            if (isfinite(gm.get(thread_id).shape_list[shape_idx]->param))
            {
                Float param_delta = gm.get(thread_id).shape_list[shape_idx]->param - param_before;
                grad_image[flat_idx] += Spectrum(param_delta, 0.f, 0.f);
            }
#endif
            px = block.nextPixel();
        }
        if (verbose)
        {
            // The shared block counter is only touched inside the critical section.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blocks_done) / blocks.size());
        }
    }
    if (verbose)
    {
        std::cout << std::endl;
    }

    // Combine the per-thread gradient slots.
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return from_spectrum_list_to_tensor(grad_image, camera.getNumPixels());
}
