#include "path2.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"
#include <signal.h>

namespace path2_meta
{
    // Unidirectional path tracer with next-event estimation (NEE) that can also
    // record the traced path for later re-evaluation.
    //
    // Parameters:
    //   scene - scene to trace against.
    //   _ray  - camera ray; copied into the local `ray`, which is overwritten by
    //           each sampled continuation ray.
    //   rRec  - supplies the sampler (rRec.sampler), the pixel index
    //           (rRec.pixel_idx) and the bounce limit (rRec.max_bounces).
    //   path  - optional recorder. When non-null it is cleared for the pixel and
    //           then receives, in order: the camera, the primary hit (pdf = 1),
    //           and per bounce: the NEE sample if it carries light (append_nee),
    //           the MIS-weighted emitter hit (append_bsdf, MIS builds only), and
    //           the next surface vertex (append). Pass nullptr to only estimate
    //           radiance.
    //
    // Returns the radiance estimate along the ray, or zero if the camera ray
    // misses the scene.
    //
    // The bounce loop stops after rRec.max_bounces iterations, at an
    // environment-map emitter vertex, when the sampled BSDF weight is zero, or
    // when the continuation ray misses the scene.
    //
    // Without MIS, emission found by BSDF-sampled rays is never added after the
    // primary hit; emitters contribute through NEE only. With MIS, NEE and
    // BSDF-sampled emitter hits are combined with the power heuristic (squared
    // pdfs).
    //
    // NOTE(review): names starting with a double underscore are reserved for the
    // C++ implementation; renaming requires updating every caller.
    Spectrum __Li(const Scene &scene, const Ray &_ray, RadianceQueryRecord &rRec, LightPath *path)
    {
        Ray ray(_ray);
        RndSampler *sampler = rRec.sampler;
        Array2i pixel_idx = rRec.pixel_idx;

        Intersection its;
        if (path)
        {
            // The recorded path always starts at the camera.
            path->clear(pixel_idx);
            path->append(scene.camera); // NOTE
        }

        const int max_depth = rRec.max_bounces;
        Spectrum ret = Spectrum::Zero();
        scene.rayIntersect(ray, true, its);
        if (!its.isValid())
            return Spectrum::Zero();

        // The camera ray is given rather than sampled here, so the primary hit
        // is recorded with pdf 1.
        its.pdf = 1.;
        if (path) path->append(its);

        Spectrum throughput = Spectrum::Ones();
        // Float eta = 1.0f;
        // Emission seen directly from the camera.
        if (its.isEmitter())
            ret += throughput * its.Le(-ray.dir);
        for (int depth = 0; depth < max_depth && its.isValid(); depth++)
        {
            // Direct illumination
            Float pdf_nee;
            Vector wo;
            DirectSamplingRecord dRec(its);
            // Stop at an environment-map emitter vertex; the path is not extended past it.
            if (its.isEmitter() && its.ptr_emitter->m_type == EnvironmentMap::TYPE_ID) break;
            auto value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
            // Direction towards the sampled light point, in the local shading frame.
            wo = its.toLocal(dRec.dir);
            if (!value.isZero(Epsilon))
            {
                auto bsdf_val = its.evalBSDF(wo);
#if defined(MIS)
                // Scoped to this block; distinct from the bsdf_pdf declared below.
                Float bsdf_pdf = its.pdfBSDF(wo);
                // Rescale the light-sampling pdf by geometric() so it is comparable
                // with bsdf_pdf (presumably an area-to-solid-angle conversion — TODO confirm).
                pdf_nee = dRec.pdf / geometric(its.p, dRec.p, dRec.n);
                auto mis_weight = square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
                ret += throughput * value * bsdf_val * mis_weight;

                // The recorded NEE sample carries the MIS-divided pdf, presumably so
                // re-evaluating the recorded path reproduces the weight — TODO confirm.
                dRec.pdf /= mis_weight;  // NOTE
                if (path) path->append_nee({dRec}); // NOTE
#else
                ret += throughput * value * bsdf_val;

                if (path) path->append_nee({dRec}); // NOTE
#endif
            }
            // Indirect illumination
            // sampleBSDF presumably writes the sampled local direction into wo along
            // with its pdf and eta; wo is converted to world space right after.
            Float bsdf_pdf, bsdf_eta;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
            if (bsdf_weight.isZero(Epsilon))
                break;
            wo = its.toWorld(wo);
            ray = Ray(its.p, wo);

            // Previous vertex position, needed for the geometric term of the new vertex.
            Vector pre_p = its.p;
            if (!scene.rayIntersect(ray, true, its))
                break;

            throughput *= bsdf_weight;
            // eta *= bsdf_eta;

#if defined(MIS)
            // BSDF-sampled ray hit an emitter: add its MIS-weighted emission.
            if (its.isEmitter())
            {
                Spectrum light_contrib = its.Le(-ray.dir);
                if (!light_contrib.isZero(Epsilon))
                {
                    // Convert the emitter-sampling pdf for this point to the BSDF's
                    // measure using cos / dist^2 (its.wi.z() presumably being the
                    // cosine at the emitter — TODO confirm).
                    auto dist_sq = (its.p - ray.org).squaredNorm();
                    auto geometry_term = its.wi.z() / dist_sq;
                    pdf_nee = scene.pdfEmitterSample(its) / geometry_term;
                    auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                    ret += throughput * light_contrib * mis_weight;

                    // Emitter-hit record uses the MIS-divided pdf; its.pdf is reset below.
                    its.pdf = bsdf_pdf * geometric(pre_p, its.p, its.geoFrame.n) / mis_weight;
                    if (path) path->append_bsdf(its); // NOTE
                }
            }
#endif
            // Regular path vertex, recorded with the unweighted BSDF pdf scaled by geometric().
            its.pdf = bsdf_pdf * geometric(pre_p, its.p, its.geoFrame.n);
            if (path) path->append(its); // NOTE
        }
        return ret;
    }
} // namespace path2_meta

// Radiance estimate for one camera ray.
//
// The active branch runs the shared tracer without recording any path
// vertices. The disabled branch records the path and re-evaluates it with
// algorithm1::eval instead, dropping non-finite results.
Spectrum Path2::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
#if 1
    // No recorder: the tracer only accumulates radiance.
    LightPath *const no_recording = nullptr;
    const Spectrum radiance = path2_meta::__Li(scene, ray, rRec, no_recording);
    return radiance;
#else
    LightPath path;
    const Spectrum traced = path2_meta::__Li(scene, ray, rRec, &path);
    if (traced.isZero(Epsilon))
        return Spectrum::Zero();
    Spectrum reevaluated = algorithm1::eval(scene, path, rRec.sampler);
    if (!reevaluated.allFinite())
        reevaluated.setZero();
    return reevaluated;
#endif
}

// Reverse-mode derivative for one camera ray.
//
// Traces a path with path2_meta::__Li while recording its vertices. If the
// path's radiance estimate is non-zero, the adjoint `d_res` is back-propagated
// through the recorded path with algorithm1::d_eval, which receives
// sceneAD.getDer() as the derivative target. Paths with zero radiance are
// skipped without building the differentiable path copy.
void Path2::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    LightPath path;
    const Spectrum value = path2_meta::__Li(sceneAD.val, ray, rRec, &path);
    if (value.isZero(Epsilon))
        return;

    // Only construct the differentiable path when it is actually consumed.
    LightPathAD pathAD(path);
    algorithm1::d_eval(sceneAD.val, sceneAD.getDer(), pathAD, d_res, rRec.sampler);
}

#include <render/spiral.h>
#include <integrator/d_scene.h>

// Adapted from Integrator::renderD; differs in how the derivative scene is
// finalized (scene_configure_d wrapper instead of Scene::configureD).
//
// Back-propagates the adjoint image `__d_image`: every pixel is processed by
// pixelColorAD() with that pixel's adjoint value, which presumably accumulates
// parameter derivatives into the per-thread storage of sceneAD.gm. The
// per-thread gradients are then merged and the derivative scene is finalized.
//
// Returns the flattened `grad_image`. NOTE(review): the per-block ImageBlocks
// put into it are never written here, so it appears to stay zero; the useful
// output is the derivative stored in sceneAD — confirm against callers.
ArrayXd Path2::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &__d_image) const
{
    PSDR_INFO("{} renderD with spp = {}, "
              "with antithetic = {}, "
              "with two point antithetic = {}",
              getName(), options.num_samples,
              enable_antithetic, two_point_antithetic);
    if (verbose)
        printf("%s renderD:\n", getName().c_str());

    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero the multi-thread gradient accumulators

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;

    const int block_size = 8;
    Spiral spiral(camera.getCropSize(), camera.getOffset(), block_size);
    ImageBlock grad_image = ImageBlock(camera.getOffset(), camera.getCropSize());
    // Adjoint image, read per pixel below.
    ImageBlock d_image = ImageBlock(camera.getOffset(), camera.getCropSize(), __d_image);

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < static_cast<int>(spiral.block_count()); i++)
    {
        auto [offset, size, block_id] = spiral.next_block();
        ImageBlock block(offset, size);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            // Sampler is seeded from (options.seed, linear pixel index).
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            PixelQueryRecord pRec{pixelIdx, &sampler, options, options.num_samples, enable_antithetic};
            pixelColorAD(sceneAD, pRec, d_image.get(pixelIdx));
        }
        grad_image.put(block);
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(spiral.block_counter()) / spiral.block_count());
        }
    }
    if (verbose)
        printf("\n");

    // Merge the per-thread gradient accumulators.
    gm.merge();

#ifdef NORMAL_PREPROCESS
    /* normal related */
    Timer _2("normal preprocess");
    Scene &d_scene = sceneAD.der;
    d_precompute_normal(scene, d_scene);
#endif

    // Note: use wrapper function to fix this runtime issue
    // d_scene.configureD(scene);
    scene_configure_d(sceneAD);

    return grad_image.flattened();
}
// Define to compute each pixel with pixelColorFwd() instead of a reverse-mode
// pixelColorAD() pass per pixel (see forwardRenderD below).
// #define USE_ENZYME_FORWARD

// Renders an image of per-pixel derivatives.
//
// Default path: for every pixel, the calling thread's parameter accumulator in
// sceneAD.gm is zeroed, pixelColorAD() back-propagates the unit adjoint
// Spectrum(1, 0, 0), and the scalar read back with getParameter() is written to
// the pixel's first channel. With USE_ENZYME_FORWARD defined, the pixel value
// comes from pixelColorFwd() instead.
//
// Afterwards, the per-thread gradients are merged and the derivative scene is
// finalized the same way as in renderD. Returns the flattened derivative image.
ArrayXd Path2::forwardRenderD(SceneAD &sceneAD, const RenderOptions &options) const {
    PSDR_INFO("{} forwardRenderD with spp = {}, "
              "with antithetic = {}, "
              "with two point antithetic = {}",
              getName(), options.num_samples,
              enable_antithetic, two_point_antithetic);
    if (verbose)
        printf("%s forwardRenderD:\n", getName().c_str());

    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero the multi-thread gradient accumulators

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;

    const int block_size = 8;
    Spiral spiral(camera.getCropSize(), camera.getOffset(), block_size);
    ImageBlock grad_image = ImageBlock(camera.getOffset(), camera.getCropSize());

    // NOTE(review): the returned index is not used below; the call is kept in
    // case it validates that some shape requires gradients — confirm.
    [[maybe_unused]] const int shape_idx = scene.getShapeRequiresGrad();

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < static_cast<int>(spiral.block_count()); i++)
    {
        auto [offset, size, block_id] = spiral.next_block();
        ImageBlock block(offset, size);
#if !defined(USE_ENZYME_FORWARD)
        const int tid = omp_get_thread_num();
#endif
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
#if !defined(USE_ENZYME_FORWARD)
            // Each pixel reads back its own derivative, so clear this thread's
            // accumulator before differentiating the pixel.
            gm.get(tid).zeroParameter();
#endif
            // Sampler is seeded from (options.seed, linear pixel index).
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            PixelQueryRecord pRec{pixelIdx, &sampler, options, options.num_samples, enable_antithetic};
#ifdef USE_ENZYME_FORWARD
            block.put(pixelIdx, pixelColorFwd(sceneAD, pRec));
#else
            pixelColorAD(sceneAD, pRec, Spectrum(1, 0, 0));
            const Float param = gm.get(tid).getParameter();
            block.put(pixelIdx, Spectrum(param, 0.f, 0.f));
#endif
        }
        grad_image.put(block);
        if (verbose)
        {
#pragma omp critical
            progressIndicator(static_cast<Float>(spiral.block_counter()) / spiral.block_count());
        }
    }
    if (verbose)
        printf("\n");

    // Merge the per-thread gradient accumulators.
    gm.merge();

#ifdef NORMAL_PREPROCESS
    /* normal related */
    Timer _2("normal preprocess");
    Scene &d_scene = sceneAD.der;
    d_precompute_normal(scene, d_scene);
#endif

    // Note: use wrapper function to fix this runtime issue
    // d_scene.configureD(scene);
    scene_configure_d(sceneAD);

    return grad_image.flattened();
}

// Forward-mode derivative for one camera ray.
//
// Traces a path with path2_meta::__Li while recording its vertices. If the
// path's radiance estimate is non-zero, the recorded path is re-evaluated with
// algorithm1::d_evalFwd and the derivative part of its (value, derivative)
// result is returned. Returns zero when the traced path carries no radiance.
// (A primal-only alternative would call algorithm1::evalFwd on `path`.)
Spectrum Path2::LiFwd(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec) const {
    LightPath path;
    const Spectrum radiance = path2_meta::__Li(sceneAD.val, ray, rRec, &path);
    if (radiance.isZero(Epsilon))
        return Spectrum::Zero();

    LightPathAD pathAD(path);
    auto [primal, d_value] = algorithm1::d_evalFwd(sceneAD.val, sceneAD.getDer(), pathAD, rRec.sampler);
    (void)primal; // only the derivative is needed here
    return d_value;
}