#include <core/ray.h>
#include <core/sampler.h>
#include <core/timer.h>
#include <render/imageblock.h>
#include <render/integrator.h>
#include <render/scene.h>
#include <render/spiral.h>
#include <medium/heterogeneous.h>
#include <core/statistics.h>

// Name -> ESamplingMode lookup table for MISIntegrator.
// NOTE(review): not referenced in this part of the file; presumably used
// elsewhere to turn a textual mode name into the enum value — confirm.
std::unordered_map<std::string, MISIntegrator::ESamplingMode> MISIntegrator::s_sampling_mode_map{
    {"solid_angle", MISIntegrator::ESolidAngle},
    {"area",        MISIntegrator::EArea},
    {"mis",         MISIntegrator::EMIS},
    {"mis_balance", MISIntegrator::EMISBalance},
    {"mis_power",   MISIntegrator::EMISPower},
};

// Builds an MIS integrator that uses the given sampling mode.
MISIntegrator::MISIntegrator(const ESamplingMode &mode)
    : m_sampling_mode(mode)
{
}

// Builds an MIS integrator from a property bag.
// The "sampling_mode" property is read as an integer and reinterpreted as an
// ESamplingMode; when absent it defaults to EMISPower.
// NOTE(review): the integer is not range-checked, so an out-of-range value
// yields an enum value with no named enumerator — consider validating.
MISIntegrator::MISIntegrator(const Properties &props)
    : m_sampling_mode(static_cast<ESamplingMode>(
          props.get<int>("sampling_mode", MISIntegrator::EMISPower)))
{
}

/*
 * Primal rendering pass.
 *
 * The camera's crop window is split into spiral-ordered tiles (block size 8)
 * that are distributed over OpenMP threads. Every pixel of a tile is estimated
 * with pixelColor(), using a RndSampler seeded from options.seed and the
 * pixel's raveled index. Finished tiles are copied into the full image.
 *
 * When `verbose` is set, a progress indicator is printed while tiles finish,
 * and the global statistics are dumped at the end.
 *
 * @param scene   scene to render
 * @param options sample count (num_samples) and RNG seed
 * @return the rendered crop window, flattened
 */
ArrayXd UnidirectionalPathTracer::renderC(const Scene &scene, const RenderOptions &options) const
{
    const auto &cam = scene.camera;
    const std::string filterName = cam.rfilter->type_name();
    PSDR_INFO("{} renderC with spp = {}, "
              "with filter = {}, "
              "with antithetic = {}, "
              "with two point antithetic = {}",
              getName(), options.num_samples, filterName,
              enable_antithetic, two_point_antithetic);
    if (verbose)
        printf("%s renderC:\n", getName().c_str());

    Timer forwardTimer("Forward rendering");
    const int numThreads = omp_get_num_procs();
    const int blockSize = 8;
    Spiral tiles(cam.getCropSize(), cam.getOffset(), blockSize);
    ImageBlock result(cam.getOffset(), cam.getCropSize());
    const int numTiles = static_cast<int>(tiles.block_count());

#pragma omp parallel for num_threads(numThreads) schedule(dynamic, 1)
    for (int tileIdx = 0; tileIdx < numTiles; ++tileIdx)
    {
        auto [origin, extent, tileId] = tiles.next_block();
        ImageBlock tile(origin, extent);

        Array2i px = tile.curPixel();
        while (tile.hasNext())
        {
            const int pixelId = ravel_multi_index(px, {cam.width, cam.height});
            RndSampler sampler(options.seed, pixelId);
            PixelQueryRecord query{px, &sampler, options, options.num_samples, enable_antithetic};
            tile.put(px, pixelColor(scene, query));
            px = tile.nextPixel();
        }

        // Copy the finished tile into its region of the full image.
        result.put(tile);

        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(tiles.block_counter()) / tiles.block_count());
        }
    }

    if (verbose)
    {
        printf("\n");
        Statistics::getInstance()->printStats();
    }
    return result.flattened();
}

/*
 * Derivative rendering pass.
 *
 * The camera's crop window is split into spiral-ordered tiles (block size 8)
 * that are processed in parallel with OpenMP. For every pixel, a RndSampler is
 * seeded exactly as in renderC, and then:
 *   - when `forward` is false (reverse mode), the pixel's adjoint value from
 *     __d_image is passed to pixelColorAD(). pixelColorAD presumably
 *     accumulates parameter gradients through sceneAD.gm, whose per-thread
 *     buffers are merged after the loop;
 *   - when `forward` is true, the value returned by pixelColorFwd() is stored
 *     in the output image.
 * Afterwards the gradient buffers are merged and d_scene.configureD(scene) is
 * called.
 *
 * @param sceneAD   scene values (val), derivative scene (der) and gradient manager (gm)
 * @param options   sample count (num_samples) and RNG seed
 * @param __d_image adjoint of the rendered image; presumably laid out like the
 *                  flattened output of renderC — TODO confirm
 * @return the flattened image assembled from the tiles. In reverse mode nothing
 *         is put into the tiles here (unless FORWARD is defined), so it
 *         presumably keeps its initial value — TODO confirm
 */
ArrayXd UnidirectionalPathTracer::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("{} renderD with spp = {}, "
              "with antithetic = {}, "
              "with two point antithetic = {}",
              getName(), options.num_samples, 
              enable_antithetic, two_point_antithetic);
    if (verbose)
        printf("%s renderD:\n", getName().c_str());
    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    Scene &d_scene = sceneAD.der;
    gm.setZero(); // reset the per-thread gradient buffers before accumulating

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;

    int block_size = 8;
    Spiral spiral(camera.getCropSize(), camera.getOffset(), block_size);
    ImageBlock grad_image = ImageBlock(camera.getOffset(), camera.getCropSize());
    // Incoming image adjoint, wrapped so it can be indexed by pixel.
    ImageBlock d_image = ImageBlock(camera.getOffset(), camera.getCropSize(), __d_image);

#if defined(FORWARD)
    // NOTE(review): shape_idx is not used below; kept in case the call has side effects.
    int shape_idx = scene.getShapeRequiresGrad();
#endif

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < static_cast<int>(spiral.block_count()); i++)
    {
        auto [offset, size, block_id] = spiral.next_block();
        ImageBlock block(offset, size);
        const int tid = omp_get_thread_num();
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
#ifdef FORWARD
            gm.get(tid).zeroParameter();
#endif
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            PixelQueryRecord pRec{pixelIdx, &sampler, options, options.num_samples, enable_antithetic};
            if (!forward)
                pixelColorAD(sceneAD, pRec, d_image.get(pixelIdx));
            else
                block.put(pixelIdx, pixelColorFwd(sceneAD, pRec));
#ifdef FORWARD
            Float param = gm.get(tid).getParameter();
            block.put(pixelIdx, Spectrum(param, 0.f, 0.f));
#endif
        }
        grad_image.put(block);
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(spiral.block_counter()) / spiral.block_count());
    }
    if (verbose)
        printf("\n");

    // Combine the per-thread gradient buffers.
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    // NOTE(review): an earlier note said this call "may crash at runtime", and
    // that this was fixed by wrapping it / overriding renderD in a subclass.
    // Confirm whether that caveat still applies.
    d_scene.configureD(scene);

    return grad_image.flattened();
}

/*
 * Radiance estimate for one pixel.
 *
 * Draws pRec.nsamples primary rays through the pixel with the camera's
 * reconstruction filter, evaluates Li() for each ray, and returns the mean.
 * NOTE(review): with nsamples == 0 this divides by zero.
 */
Spectrum UnidirectionalPathTracer::pixelColor(const Scene &scene, PixelQueryRecord &pRec) const
{
    const Camera &camera = scene.camera;
    Spectrum accum = Spectrum::Zero();
    for (int s = 0; s < pRec.nsamples; ++s)
    {
        Ray primary;
        camera.samplePrimaryRayFromFilter(pRec.pixel_idx.x(), pRec.pixel_idx.y(),
                                          pRec.sampler->next2D(), primary);
        RadianceQueryRecord radianceQuery(pRec);
        accum += Li(scene, primary, radianceQuery);
    }
    return accum / pRec.nsamples;
}

/*
 * Reverse-mode (adjoint) counterpart of pixelColor().
 *
 * The pixel adjoint _d_res is split evenly across pRec.nsamples camera samples,
 * and each share is propagated with LiAD().
 *
 * With enable_antithetic, one filter sample yields four rays. The sampler state
 * is saved after the camera sample and restored before each traced ray, so the
 * rays reuse the same random numbers:
 *   - two_point_antithetic: only rays[0] and rays[3] are traced, weight 1/2 each;
 *   - otherwise:            all four rays are traced, weight 1/4 each.
 */
void UnidirectionalPathTracer::pixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &_d_res) const
{
    const auto &camera = sceneAD.val.camera;
    Spectrum perSample = _d_res / pRec.nsamples;

    for (int s = 0; s < pRec.nsamples; ++s)
    {
        if (!enable_antithetic)
        {
            Ray ray;
            camera.samplePrimaryRayFromFilter(pRec.pixel_idx.x(), pRec.pixel_idx.y(),
                                              pRec.sampler->next2D(), ray);
            LiAD(sceneAD, ray, pRec, perSample);
            continue;
        }

        Ray rays[4];
        camera.samplePrimaryRayFromFilter(pRec.pixel_idx.x(), pRec.pixel_idx.y(),
                                          pRec.sampler->next2D(), rays);
        pRec.sampler->save();
        if (two_point_antithetic)
        {
            LiAD(sceneAD, rays[0], pRec, perSample / 2.0);
            pRec.sampler->restore();
            LiAD(sceneAD, rays[3], pRec, perSample / 2.0);
        }
        else
        {
            for (int k = 0; k < 4; ++k)
            {
                pRec.sampler->restore();
                LiAD(sceneAD, rays[k], pRec, perSample / 4.0);
            }
        }
    }
}

/*
 * Forward-mode counterpart of pixelColor().
 *
 * Averages LiFwd() over pRec.nsamples camera samples and returns the mean.
 *
 * With enable_antithetic, one filter sample yields four rays. The sampler state
 * is saved after the camera sample and restored before each traced ray:
 *   - two_point_antithetic: only rays[0] and rays[3] are traced, weight 1/2 each;
 *   - otherwise:            all four rays are traced, weight 1/4 each.
 *
 * The two-point choice is made from the runtime member `two_point_antithetic`,
 * the same flag pixelColorAD() uses and renderD() logs. It was previously tied
 * to the TWO_POINT_ANTITHETIC macro, whose branch wrongly called LiAD() instead
 * of LiFwd().
 * NOTE(review): with nsamples == 0 this divides by zero, as pixelColor() does.
 */
Spectrum UnidirectionalPathTracer::pixelColorFwd(SceneAD &sceneAD, PixelQueryRecord &pRec) const
{
    const Scene &scene = sceneAD.val;
    const auto &cam = scene.camera;
    Spectrum d_res = Spectrum::Zero();
    for (int i = 0; i < pRec.nsamples; i++)
    {
        if (enable_antithetic)
        {
            Ray rays[4];
            cam.samplePrimaryRayFromFilter(pRec.pixel_idx.x(), pRec.pixel_idx.y(), pRec.sampler->next2D(), rays);
            pRec.sampler->save();
            if (!two_point_antithetic)
            {
                for (int k = 0; k < 4; ++k)
                {
                    pRec.sampler->restore();
                    d_res += LiFwd(sceneAD, rays[k], pRec) / 4.0;
                }
            }
            else
            {
                d_res += LiFwd(sceneAD, rays[0], pRec) / 2.0;
                pRec.sampler->restore();
                d_res += LiFwd(sceneAD, rays[3], pRec) / 2.0;
            }
        }
        else
        {
            Ray ray;
            cam.samplePrimaryRayFromFilter(pRec.pixel_idx.x(), pRec.pixel_idx.y(), pRec.sampler->next2D(), ray);
            d_res += LiFwd(sceneAD, ray, pRec);
        }
    }
    return d_res / pRec.nsamples;
}