#pragma once
#ifndef INTEGRATOR_H__
#define INTEGRATOR_H__

#include <core/fwd.h>
#include <core/ptr.h>
#include <render/medium.h>
#include <render/common.h>
struct Scene;
struct SceneAD;
struct RndSampler;
struct Ray;
struct RenderOptions;

/// Per-query state handed to radiance evaluation (`Li` / `LiAD`).
///
/// The record is a plain aggregate of values supplied by the caller; it does
/// not own or free `sampler`.
struct RadianceQueryRecord
{
    /// Leaves `pixel_idx` as its type's default and all scalar members at the
    /// in-class defaults below. Callers are expected to fill in the fields.
    RadianceQueryRecord() = default;

    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler, int max_bounces)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(max_bounces) {}

    /// Takes the bounce limit from `options.max_bounces`.
    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                        const RenderOptions &options)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(options.max_bounces) {}

    /// Same as above, additionally setting the starting medium id.
    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                        const RenderOptions &options, int med_id)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(options.max_bounces), med_id(med_id) {}

    Array2i pixel_idx;
    RndSampler *sampler = nullptr; // not owned by this record
    // Previously left indeterminate by the default constructor. 0 is only a
    // deterministic placeholder; code that default-constructs a record must
    // set a real limit before use.
    int max_bounces = 0;
    int med_id = -1;  // NOTE(review): -1 presumably means "no medium"; confirm.
    int depth = 0;    // current depth
};

/// A radiance query extended with per-pixel sampling settings.
///
/// Each constructor forwards the query fields to the matching
/// `RadianceQueryRecord` constructor and stores the sample count and the
/// antithetic-sampling flag.
struct PixelQueryRecord : RadianceQueryRecord
{
    /// Leaves every member at its in-class default (see below and the base).
    PixelQueryRecord() = default;

    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler, int max_bounces, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, max_bounces), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                     const RenderOptions &options, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, options), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                     const RenderOptions &options, int med_id, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, options, med_id), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    // Previously indeterminate after default construction; 0 is a
    // deterministic placeholder that callers are expected to overwrite.
    int nsamples = 0;
    bool enable_antithetic = true;
};

/// Abstract integrator interface.
///
/// Concrete integrators provide a forward renderer (`renderC`) and a
/// differentiable renderer (`renderD`). Both return an `ArrayXd`.
struct Integrator
{
    /// Virtual so integrators can be destroyed through a base pointer.
    virtual ~Integrator() = default;

    /// Forward rendering of `scene` with the given `options`.
    virtual ArrayXd renderC(const Scene &scene, const RenderOptions &options) const = 0;

    /// Differentiable rendering. Receives the image-space adjoint `d_image`;
    /// presumably propagates it into `sceneAD` (TODO confirm in implementations).
    virtual ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options,
                            const ArrayXd &d_image) const = 0;

    /// Antithetic-sampling switch; defaults to enabled.
    bool enable_antithetic = true;
};

struct UnidirectionalPathTracer : Integrator
{
    virtual ~UnidirectionalPathTracer() {}

    // forward rendering
    virtual ArrayXd renderC(const Scene &scene, const RenderOptions &options) const;

    // differentiable rendering
    virtual ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const;

    // query pixel color
    Spectrum pixelColor(const Scene &scene, PixelQueryRecord &pRec) const;

    void pixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &d_res) const;
    
    // query radiance
    virtual Spectrum Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const = 0;

    virtual void LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const = 0;

    virtual std::string getName() const = 0;
};

namespace
{
    struct DebugInfo
    {
        DebugInfo(){};
        DebugInfo(int nworkers, int npixels, int nsamples)
            : nworkers(nworkers), npixels(npixels), nsamples(nsamples)
        {
            image_per_thread.resize(nworkers);
            for (int i = 0; i < nworkers; i++)
            {
                image_per_thread[i].resize(npixels);
                for (int j = 0; j < image_per_thread[i].size(); j++)
                {
                    image_per_thread[i][j] = Spectrum::Zero();
                }
            }
        }
        int nworkers;
        int npixels;
        int nsamples;
        std::vector<std::vector<Spectrum>> image_per_thread;
        std::vector<Spectrum> getData()
        {
            std::vector<Spectrum> ret(npixels, Spectrum::Zero());
            for (int i = 0; i < nworkers; i++)
            {
                for (int j = 0; j < npixels; j++)
                {
                    ret[j] += image_per_thread[i][j];
                }
            }
            return ret;
        }

        ArrayXd getArray()
        {
            ArrayXd ret = ArrayXd::Zero(npixels * 3);
            for (int i = 0; i < nworkers; i++)
            {
                ret = ret + from_spectrum_list_to_tensor(
                                image_per_thread[i], npixels);
            }
            return ret;
        }

        void save(ptr<float> d_image)
        {
            std::vector<Spectrum> img = getData();
            // from_spectrum_list_to_ptr(img, npixels, d_image);
        }
    };
} // namespace

#endif // INTEGRATOR_H__
