#include "direct.h"
#include "../emitter/area.h"
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>

namespace
{
    /**
     * Direct-illumination estimate along a primary camera ray.
     *
     * Intersects `ray` with the scene and, if a surface is hit, accumulates:
     *   - emission at the hit point (when the hit is an emitter),
     *   - one emitter-sampling (next-event estimation) sample, and
     *   - when compiled with MIS, one BSDF-sampling sample; the two sampling
     *     strategies are combined with the power heuristic.
     * Every contribution gathered at the primary hit is scaled by `its.J`,
     * and the total is scaled by the camera filter weight for `pixel_idx`.
     *
     * @param scene     scene to trace against
     * @param sampler   random-number source; advanced by this call
     * @param ray       primary camera ray
     * @param pixel_idx pixel whose filter weight is applied
     * @return filtered radiance, or zero when the ray hits nothing
     */
    Spectrum radiance(const Scene &scene, RndSampler *sampler, Ray &ray,
                      const Array2i &pixel_idx)
    {
        Intersection its;
        Float throughput = 1.;
        Spectrum ret = Spectrum::Zero();
        // Perform the first intersection
        scene.rayIntersect(ray, true, its, IntersectionMode::EMaterial);
        if (its.isValid())
        {
            throughput *= scene.camera.evalFilter(pixel_idx.x(), pixel_idx.y(), its);
            if (its.isEmitter())
                ret += its.Le(-ray.dir) * its.J;
            // Direct illumination: sample a point on an emitter
            DirectSamplingRecord dRec(its);
            auto value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
            if (!value.isZero())
            {
                auto bsdf_val = its.evalBSDF(its.toLocal(dRec.dir));
#ifdef MIS
                // BSDF pdf converted with dRec.G so it can be compared with dRec.pdf
                Float bsdf_pdf = its.pdfBSDF(its.toLocal(dRec.dir)) * dRec.G;

                Float mis_weight = square(dRec.pdf) / (square(dRec.pdf) + square(bsdf_pdf));
                // its.J is applied here exactly as in the non-MIS path and the
                // BSDF-sampling branch below; omitting it made the two builds disagree.
                ret += value * bsdf_val * mis_weight * its.J;
#else
                ret += value * bsdf_val * its.J;
#endif
            }

#ifdef MIS
            // BSDF sampling, MIS-weighted against emitter sampling
            Float bsdf_pdf, bsdf_eta;
            Vector wo;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
            if (!bsdf_weight.isZero())
            {
                wo = its.toWorld(wo);
                // Separate name so the camera-ray parameter `ray` is not shadowed
                Ray ray_bsdf(its.p, wo);
                Intersection its_light;
                if (scene.rayIntersect(ray_bsdf, true, its_light, IntersectionMode::EMaterial))
                {
                    if (its_light.isEmitter())
                    {
                        Float G = geometric(ray_bsdf.org, its_light.p, its_light.geoFrame.n);
                        bsdf_pdf *= detach(G); // area measure
                        auto le = its_light.Le(-ray_bsdf.dir) * its_light.J * G;
                        if (!le.isZero())
                        {
                            bsdf_weight = its.evalBSDF(its.toLocal(its_light.p - its.p).normalized());
                            Float pdf_nee = scene.pdfEmitterSample(its_light);
                            Float mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                            ret += bsdf_weight * le * mis_weight * its.J / bsdf_pdf;
                        }
                    }
                }
            }
#endif
        }
        return throughput * ret;
    }

    /**
     * One antithetic pair of radiance samples for a pixel.
     *
     * The camera yields a primal ray and its antithetic dual. The sampler is
     * saved before tracing the primal ray and restored before tracing the
     * dual, so both presumably consume the same random numbers. The two
     * estimates are averaged. This is the function differentiated by Enzyme
     * in d_pixelColor and Direct::d_Li.
     *
     * @param n_samples unused; kept so the signature matches the Enzyme calls
     * @param res       overwritten with the averaged radiance
     */
    void Li(const Scene &scene, const Array2i &pixel_idx,
            RndSampler *sampler, [[maybe_unused]] int n_samples, Spectrum &res)
    {
        res = Spectrum::Zero();
        const auto &cam = scene.camera;
        Ray ray_primal, ray_dual;
        // antithetic sampling
        cam.samplePrimaryRayFromFilter(pixel_idx.x(), pixel_idx.y(), sampler->next2D(), ray_primal, ray_dual);
        sampler->save();
        res += radiance(scene, sampler, ray_primal, pixel_idx);
        sampler->restore();
        res += radiance(scene, sampler, ray_dual, pixel_idx);
        res /= 2;
    }

    /**
     * Monte Carlo pixel estimate: the mean of `n_samples` calls to Li().
     *
     * @param res overwritten with the estimate; left at zero when
     *            `n_samples <= 0`, which avoids a 0/0 division
     */
    void pixelColor(const Scene &scene, const Array2i &pixel_idx,
                    RndSampler *sampler, int n_samples, Spectrum &res)
    {
        res = Spectrum::Zero();
        if (n_samples <= 0)
            return;
        for (int i = 0; i < n_samples; i++)
        {
            Spectrum val;
            Li(scene, pixel_idx, sampler, n_samples, val);
            res += val;
        }
        res /= n_samples;
    }

    /**
     * Reverse-mode derivative of pixelColor() computed with Enzyme.
     *
     * pixelColor averages `n_sample` Li() calls, so each differentiated Li()
     * call receives the output adjoint `d_res / n_sample`. Scene derivatives
     * are accumulated into `d_scene`, the shadow of `scene`.
     */
    void d_pixelColor(const Scene &scene, Scene &d_scene,
                      const Array2i &pixel_idx, RndSampler *sampler,
                      int n_sample, Spectrum d_res)
    {
        // Primal output of Li; not needed (enzyme_dupnoneed)
        Spectrum res;
        for (int i = 0; i < n_sample; i++)
        {
            // Rebuilt every iteration: the output shadow handed to Enzyme may be
            // consumed by the reverse pass. NOTE(review): confirm before hoisting.
            [[maybe_unused]] Spectrum _d_res = d_res / n_sample;
            __enzyme_autodiff((void *)Li,
                              enzyme_dup, &scene, &d_scene,
                              enzyme_const, &pixel_idx,
                              enzyme_const, sampler,
                              enzyme_const, n_sample,
                              enzyme_dupnoneed, &res, &_d_res);
        }
    }
}

Spectrum Direct::radiance(const Scene &scene, RndSampler *sampler, Ray &ray,
                          const Array2i &pixel_idx) const
{
    // Forward to the file-local implementation. The global-scope qualifier is
    // required: unqualified lookup inside this member would find
    // Direct::radiance itself and recurse.
    Spectrum estimate = ::radiance(scene, sampler, ray, pixel_idx);
    return estimate;
}

Spectrum Direct::Li(const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx) const
{
    // The file-local ::Li takes the pixel before the sampler and writes its
    // result through an out-parameter; it is handed a sample count of 1.
    Spectrum estimate = Spectrum::Zero();
    const int sample_count = 1;
    ::Li(scene, pixel_idx, sampler, sample_count, estimate);
    return estimate;
}

void Direct::d_Li(SceneAD &sceneAD, const Array2i &pixel_idx,
                  RndSampler *sampler, int n_samples, Spectrum &d_res) const
{
    // Reverse-mode differentiation of ::Li with Enzyme:
    //   - the scene is active: primal sceneAD.val, shadow sceneAD.der;
    //   - pixel index, sampler and sample count are constants;
    //   - the output Spectrum is enzyme_dupnoneed: its primal value is not
    //     needed, only its adjoint `d_res` is propagated back.
    auto &primal_scene = sceneAD.val;
    auto &shadow_scene = sceneAD.der;
    Spectrum unused_primal_out;
    __enzyme_autodiff((void *)::Li,
                      enzyme_dup, &primal_scene, &shadow_scene,
                      enzyme_const, &pixel_idx,
                      enzyme_const, sampler,
                      enzyme_const, n_samples,
                      enzyme_dupnoneed, &unused_primal_out, &d_res);
}

void Direct::render(const Scene &scene, const RenderOptions &options,
                    ptr<float> rendered_image) const
{
    Timer _("Forward rendering");
    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    int blockProcessed = 0;
    std::vector<Spectrum> spec_list(camera.getNumPixels(), Spectrum::Zero());
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            Spectrum pixel_val = Spectrum::Zero();
            pixelColor(scene, pixelIdx, &sampler, options.num_samples, pixel_val);
            spec_list[pixel_ravel_idx] += pixel_val;
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;
}

ArrayXd Direct::renderC(const Scene &scene, const RenderOptions &options) const
{
    // Forward render every pixel, parallelised with OpenMP over 16x16 tiles,
    // and return the image converted by from_spectrum_list_to_tensor.
    Timer _("Forward rendering");
    const int worker_count = omp_get_num_procs();
    const auto &cam = scene.camera;
    BlockedImage tiles({cam.width, cam.height}, {16, 16});
    std::vector<Spectrum> pixels(cam.getNumPixels(), Spectrum::Zero());
    int tiles_done = 0;
#pragma omp parallel for num_threads(worker_count) schedule(dynamic, 1)
    for (int t = 0; t < tiles.m_BlocksTotal; t++)
    {
        ImageBlock tile = tiles.getBlock(t);
        for (Array2i px = tile.curPixel(); tile.hasNext(); px = tile.nextPixel())
        {
            // Sampler seeded per pixel with (seed, flat index); presumably this
            // keeps results independent of the thread schedule
            const int flat = ravel_multi_index(px, {cam.width, cam.height});
            RndSampler rng(options.seed, flat);
            Spectrum value = Spectrum::Zero();
            pixelColor(scene, px, &rng, options.num_samples, value);
            pixels[flat] += value;
        }
        if (verbose)
        {
            // Shared counter: increment and report under the critical section
#pragma omp critical
            progressIndicator(static_cast<Float>(++tiles_done) / tiles.size());
        }
    }
    if (verbose)
        std::cout << std::endl;

    return from_spectrum_list_to_tensor(pixels, cam.getNumPixels());
}

/**
 * Differentiate the direct-lighting image with respect to the scene.
 *
 * Each pixel is processed by d_pixelColor (Enzyme reverse mode) with the
 * pixel's entry of `__d_image` as the output adjoint. Derivatives go into
 * this thread's gradient scene gm.get(tid); the per-thread gradients are
 * combined by gm.merge() after the parallel loop (presumably into
 * sceneAD.der — confirm against GradientManager).
 *
 * @param sceneAD   primal scene (val), derivative scene (der) and gradient
 *                  manager (gm); gm is zeroed at entry
 * @param options   seed and samples per pixel; each pixel's sampler is seeded
 *                  with (options.seed, pixel index), as in renderC
 * @param __d_image image adjoint, split into one Spectrum per pixel
 * @return a per-pixel image that is only filled when FORWARD is defined. In
 *         that case channel 0 holds the change in the selected shape's
 *         `param` for the pixel. Otherwise the image is all zeros.
 */
ArrayXd Direct::renderD(SceneAD &sceneAD,
                        const RenderOptions &options, const ArrayXd &__d_image) const
{
    const Scene &scene = sceneAD.val;
    // Only referenced in the NORMAL_PREPROCESS section below
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    const auto d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image, camera.getNumPixels());
    // gradient image
    std::vector<Spectrum> g_image_spec_list(camera.getNumPixels(), Spectrum::Zero());
#if defined(FORWARD)
    // Index of the shape whose `param` is tracked in FORWARD mode
    int shape_idx = scene.getShapeRequiresGrad();
#endif

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // Selects this thread's gradient scene inside gm
        const int tid = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
#ifdef FORWARD
            // Reset and record this thread's `param` so the difference after
            // d_pixelColor is this pixel's contribution alone
            gm.get(tid).shape_list[shape_idx]->param = 0;
            Float param = gm.get(tid).shape_list[shape_idx]->param;
#endif
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            d_pixelColor(scene, gm.get(tid),
                         pixelIdx, &sampler, options.num_samples, d_image_spec_list[pixel_ravel_idx]);
#ifdef FORWARD
            param = gm.get(tid).shape_list[shape_idx]->param - param;
            g_image_spec_list[pixel_ravel_idx] += Spectrum(param, 0.f, 0.f);
#endif
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;

    // Combine the per-thread gradients gathered above
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return from_spectrum_list_to_tensor(g_image_spec_list, camera.getNumPixels());
}
