#include "direct2.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"
namespace
{
    std::pair<Spectrum, LightPath>
    __radiance(const Scene &scene, RndSampler *sampler, Ray &ray,
               const Array2i &pixel_idx)
    {
        LightPath path(pixel_idx);
        path.append(scene.camera);
        Float pdf = 1.;
        Intersection its;
        Spectrum ret = Spectrum::Zero();
        // Perform the first intersection
        scene.rayIntersect(ray, true, its, IntersectionMode::EMaterial);
        if (!its.isValid())
            return {ret, path};

        its.pdf = 1.;
        path.append(its);

        if (its.isEmitter())
            ret += its.Le(-ray.dir) * its.J;

        DirectSamplingRecord dRec(its);
        Spectrum value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
        if (!value.isZero(Epsilon))
            path.append_nee({dRec});

        return {ret, path};
    }

    Spectrum radiance(const Scene &scene, RndSampler *sampler, Ray &ray,
                      const Array2i &pixel_idx)
    {
        auto [value, path] = __radiance(scene, sampler, ray, pixel_idx);
        return algorithm1::eval(scene, path, sampler);
    }

    void d_radiance(const Scene &scene, Scene &d_scene,
                    RndSampler *sampler, Ray &ray, const Array2i &pixel_idx,
                    const Spectrum &d_value)
    {
        auto [value, path] = __radiance(scene, sampler, ray, pixel_idx);
        LightPathAD pathAD(path);
        algorithm1::d_eval(scene, d_scene, pathAD, d_value, sampler);
    }

    void Li(const Scene &scene, const Array2i &pixel_idx,
            RndSampler *sampler, int n_samples, Spectrum &res)
    {
        res = Spectrum::Zero();
        const auto &cam = scene.camera;
        Ray ray_primal, ray_dual;
        // antithetic sampling
        cam.samplePrimaryRayFromFilter(pixel_idx.x(), pixel_idx.y(), sampler->next2D(), ray_primal, ray_dual);
        sampler->save();
        res += radiance(scene, sampler, ray_primal, pixel_idx);
        sampler->restore();
        res += radiance(scene, sampler, ray_dual, pixel_idx);
        res /= 2;
    }

    void d_Li(const Scene &scene, Scene &d_scene, const Array2i &pixel_idx,
              RndSampler *sampler, int n_samples, Spectrum &d_res)
    {
        const auto &cam = scene.camera;
        Ray ray_primal, ray_dual;
        // antithetic sampling
        cam.samplePrimaryRayFromFilter(pixel_idx.x(), pixel_idx.y(), sampler->next2D(), ray_primal, ray_dual);
        sampler->save();
        d_radiance(scene, d_scene, sampler, ray_primal, pixel_idx, d_res / 2.);
        sampler->restore();
        d_radiance(scene, d_scene, sampler, ray_dual, pixel_idx, d_res / 2.);
    }

    void pixelColor(const Scene &scene, const Array2i &pixel_idx,
                    RndSampler *sampler, int n_samples, Spectrum &res)
    {
        res = Spectrum::Zero();
        for (int i = 0; i < n_samples; i++)
        {
            Spectrum val;
            Li(scene, pixel_idx, sampler, n_samples, val);
            res += val;
        }
        res /= n_samples;
    }

    void d_pixelColor(const Scene &scene, Scene &d_scene,
                      const Array2i &pixel_idx, RndSampler *sampler,
                      int n_sample, Spectrum d_res)
    {
        Spectrum res;
        for (int i = 0; i < n_sample; i++)
        {
            Spectrum _d_res = d_res / n_sample;
            d_Li(scene, d_scene, pixel_idx, sampler, n_sample, _d_res);
        }
    }
}

// forward rendering
ArrayXd Direct2::renderC(const Scene &scene, const RenderOptions &options) const
{
    // Forward pass: the image is split into 16x16 tiles that OpenMP threads
    // pick up dynamically; each pixel gets its own sampler and estimate.
    Timer _("Forward rendering");
    const int threadCount = omp_get_num_procs();
    const auto &cam = scene.camera;
    const auto pixelCount = cam.getNumPixels();
    BlockedImage tiles({cam.width, cam.height}, {16, 16});
    std::vector<Spectrum> accum(pixelCount, Spectrum::Zero());
    int finishedTiles = 0;
#pragma omp parallel for num_threads(threadCount) schedule(dynamic, 1)
    for (int tileId = 0; tileId < tiles.m_BlocksTotal; tileId++)
    {
        ImageBlock tile = tiles.getBlock(tileId);
        for (Array2i px = tile.curPixel(); tile.hasNext(); px = tile.nextPixel())
        {
            const int flat = ravel_multi_index(px, {cam.width, cam.height});
            // Fresh sampler per pixel, seeded with (options.seed, flat index).
            RndSampler sampler(options.seed, flat);
            Spectrum color = Spectrum::Zero();
            pixelColor(scene, px, &sampler, options.num_samples, color);
            accum[flat] += color;
        }
        // The shared tile counter is only touched inside the critical section.
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(++finishedTiles) / tiles.size());
        }
    }
    if (verbose)
        std::cout << std::endl;
    return from_spectrum_list_to_tensor(accum, pixelCount);
}

// rendering gradient image
// Backward pass: propagates the per-pixel adjoint image `_d_image` through the
// direct-lighting estimator into the per-thread gradient buffers of
// `sceneAD.gm`, which are merged once all tiles are done.
// Returns `g_image_spec_list`: it is only written when FORWARD is defined
// (per-pixel derivative of one shape parameter, stored in the first channel);
// otherwise the returned image is all zeros.
ArrayXd Direct2::renderD(SceneAD &sceneAD,
                         const RenderOptions &options, const ArrayXd &_d_image) const
{
    const Scene &scene = sceneAD.val;
    // d_scene is only used by the NORMAL_PREPROCESS step below.
    Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient
    const int nworker = omp_get_num_procs();

    // Uncomment to force a single-threaded run (debugging).
    // const int nworker = 1;
    Timer _("Render interior");

    const auto &camera = scene.camera;
    // NOTE(review): presumably a valid camera.rect denotes a cropped render,
    // which this path does not support — confirm.
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    // Incoming adjoint, one Spectrum per pixel.
    std::vector<Spectrum> spec_list = from_tensor_to_spectrum_list(_d_image, camera.getNumPixels());

    // gradient image
    std::vector<Spectrum> g_image_spec_list(camera.getNumPixels(), Spectrum::Zero());
#if defined(FORWARD)
    int shape_idx = scene.getShapeRequiresGrad();
#endif

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // Each thread writes gradients into its own buffer gm.get(tid).
        const int tid = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
#ifdef FORWARD
            // Reset the tracked parameter's gradient so that, after
            // d_pixelColor, the difference below is this pixel's contribution.
            gm.get(tid).shape_list[shape_idx]->param = 0;
            Float param = gm.get(tid).shape_list[shape_idx]->param;
#endif
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            d_pixelColor(scene, gm.get(tid),
                         pixelIdx, &sampler, options.num_samples, spec_list[pixel_ravel_idx]);
#ifdef FORWARD
            // Non-finite gradients are skipped rather than written to the image.
            if (isfinite(gm.get(tid).shape_list[shape_idx]->param))
            {
                param = gm.get(tid).shape_list[shape_idx]->param - param;
                g_image_spec_list[pixel_ravel_idx] += Spectrum(param, 0.f, 0.f);
            }
#endif
        }
        // Progress counter is shared across threads; updated under a critical section.
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;

    // Combine the per-thread gradient buffers.
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return from_spectrum_list_to_tensor(g_image_spec_list, camera.getNumPixels());
}
