#pragma once
#ifndef INTEGRATOR_H__
#define INTEGRATOR_H__

#include <core/fwd.h>
#include <core/ptr.h>
#include <core/logger.h>
#include <render/medium.h>
#include <render/common.h>
#include <render/intersection.h>
struct Scene;
struct SceneAD;
struct RndSampler;
struct Ray;
struct RenderOptions;

/// Per-query state handed to radiance evaluation: the pixel being rendered,
/// its random sampler, the bounce budget and the mutable path state
/// (`incEmission`, `med_id`, `depth`).
struct RadianceQueryRecord {
    RadianceQueryRecord() = default;

    /// @param pixel_idx   pixel this query belongs to
    /// @param sampler     random sampler used by the query (not deleted by this record)
    /// @param max_bounces bounce budget for the query
    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler, int max_bounces)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(max_bounces) {}

    /// Same as above, taking the bounce budget from `options.max_bounces`.
    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                        const RenderOptions &options)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(options.max_bounces) {}

    /// Same as above, additionally storing the medium id `med_id`.
    RadianceQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                        const RenderOptions &options, int med_id)
        : pixel_idx(pixel_idx), sampler(sampler), max_bounces(options.max_bounces), med_id(med_id) {}

    Array2i     pixel_idx;
    RndSampler *sampler     = nullptr; // not owned/deleted by this record
    bool        incEmission = true;    // presumably "include emission" at the next hit — confirm in Li()
    // Zero-initialized so a default-constructed record never carries an
    // indeterminate bounce budget (reading an uninitialized int is UB).
    int         max_bounces = 0;
    int         med_id      = -1;      // -1 by default; presumably "no medium" — confirm
    int         depth       = 0;       // current depth
};

/// A RadianceQueryRecord extended with per-pixel sampling controls:
/// a sample count and an antithetic-sampling switch.
struct PixelQueryRecord : RadianceQueryRecord {
    PixelQueryRecord() = default;

    /// Explicit bounce budget; `enable_antithetic` defaults to true.
    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler, int max_bounces, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, max_bounces), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    /// Bounce budget taken from `options`.
    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                     const RenderOptions &options, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, options), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    /// Bounce budget taken from `options`, plus a medium id `med_id`.
    PixelQueryRecord(const Array2i &pixel_idx, RndSampler *sampler,
                     const RenderOptions &options, int med_id, int nsamples, bool enable_antithetic = true)
        : RadianceQueryRecord(pixel_idx, sampler, options, med_id), nsamples(nsamples), enable_antithetic(enable_antithetic) {}

    // Zero-initialized so a default-constructed record has a defined sample count.
    int  nsamples          = 0;
    bool enable_antithetic = true;
};

struct Integrator {
    Integrator(const Properties &props = Properties())
        : enable_antithetic(props.get<bool>("enable_antithetic", true)),
          two_point_antithetic(props.get<bool>("two_point_antithetic", true)) {}

    virtual ~Integrator() {}
    virtual void configure([[maybe_unused]] const Scene &scene) {}
    // forward rendering
    virtual ArrayXd renderC(const Scene &scene, const RenderOptions &options) const = 0;

    // differentiable rendering
    virtual ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const = 0;

    bool enable_antithetic = true;
    bool two_point_antithetic = true;
};

/// Sampling-mode flags plus the path-vertex bookkeeping (Vertex, MISContext)
/// used to accumulate forward/reverse densities along a path.
struct MISIntegrator {

    enum ESamplingMode {
        ESolidAngle         = 1,
        EArea               = 1 << 1,
        EMISBalance         = 1 << 2,
        EMISPower           = 1 << 3,
        EMISFirst           = 1 << 4,
        ESkipSensor         = 1 << 5,
        EDebugMISPower      = 1 << 6, // used for debugging
        EDebugMISBalance    = 1 << 7, // used for debugging
        EMISPath            = 1 << 8,
        ENEE                = 1 << 9,
        EDebugMIS           = EDebugMISPower | EDebugMISBalance,
        EMIS                = EMISBalance | EMISPower,
        EAreaSampling       = EArea | EMIS,
        ESolidAngleSampling = ESolidAngle | EMIS,
    };

    /// One path vertex. Only the fields relevant to `type` are set by the
    /// factory functions; the rest keep the defaults declared below.
    struct Vertex {
        enum EType {
            ESurface  = 1 << 0,
            EMedium   = 1 << 1,
            EBoundary = 1 << 2,
            ECamera   = 1 << 3,
        };

        /// Surface vertex; stores a pointer to `its`, which must outlive the vertex.
        static inline Vertex createSurface(const Intersection &its) {
            Vertex v;
            v.type = ESurface;
            v.its  = &its;
            return v;
        }

        /// Medium vertex at `p`, with forward density `pdf` and next-direction density `pdf_next`.
        static inline Vertex createMedium(const Vector &p, const Medium *medium, Float pdf, Float pdf_next) {
            Vertex v;
            v.type     = EMedium;
            v.p        = p;
            v.medium   = medium;
            v.pdf_fwd  = pdf;
            v.pdf_next = pdf_next;
            return v;
        }

        /// Boundary vertex; stores a pointer to `its`, which must outlive the vertex.
        static inline Vertex createBoundary(const Intersection &its, Float pdf, Float pdf_next) {
            Vertex v;
            v.type     = EBoundary;
            v.its      = &its;
            v.pdf_fwd  = pdf;
            v.pdf_next = pdf_next;
            return v;
        }

        /// Camera vertex; stores a pointer to `camera`, which must outlive the vertex.
        static inline Vertex createCamera(const Camera &camera, const Array2i &pixel_idx, Float pdf_next, Float pdf_rev) {
            Vertex v;
            v.type      = ECamera;
            v.camera    = &camera;
            v.pixel_idx = pixel_idx;
            v.pdf_next  = pdf_next;
            v.pdf_fwd   = 1.;
            v.pdf_rev   = pdf_rev; // due to the random choice of the pixel when connecting to the camera
            return v;
        }

        /// Position of the vertex; throws for an unknown/unset type.
        Vector getP() const {
            switch (type) {
            case ESurface:
            case EBoundary:
                return its->p;
            case EMedium:
                return p;
            case ECamera:
                return camera->cpos;
            default:
                // Cast: an unscoped enum is not directly formattable by newer {fmt}.
                Throw("Invalid vertex type {}", static_cast<int>(type));
            }
        }

        /// Multiplies `pdf` by the geometric term between this vertex and
        /// `next`, using `next`'s geometric normal when it lies on a surface.
        Float convertDensity(Float pdf, const Vertex &next) const {
            if (next.type == ESurface || next.type == EBoundary) {
                return pdf * geometric(getP(), next.getP(), next.its->geoFrame.n);
            }
            return pdf * geometric(getP(), next.getP());
        }

        // Zero-initialized: not any valid EType, so getP() on a vertex that
        // no factory initialized throws instead of reading garbage.
        EType type{};

        // Medium Interaction
        Vector        p;
        const Medium *medium = nullptr;

        // Surface Interaction
        const Intersection *its = nullptr;

        // Camera
        Array2i       pixel_idx;
        const Camera *camera = nullptr;

        // Densities default to 1, the multiplicative identity. A density that no
        // factory sets (e.g. pdf_rev for surface/medium/boundary vertices)
        // therefore leaves MISContext::pdfFwd/pdfRev unchanged instead of
        // reading an indeterminate value.
        Float pdf_fwd  = 1.;
        Float pdf_rev  = 1.;
        Float pdf_next = 1.; // phase pdf or bsdf pdf
    };

    /// Ordered list of path vertices with products of their densities.
    struct MISContext {
        void append(const Vertex &v) {
            vertices.push_back(v);
        }

        /// Product of every vertex's forward density (1 for an empty path).
        Float pdfFwd() const {
            Float pdf = 1.;
            for (const Vertex &v : vertices) {
                pdf *= v.pdf_fwd;
            }
            return pdf;
        }

        /// Product of every vertex's reverse density (1 for an empty path).
        Float pdfRev() const {
            Float pdf = 1.;
            for (const Vertex &v : vertices) {
                pdf *= v.pdf_rev;
            }
            return pdf;
        }

        std::vector<Vertex> vertices;
    };

    MISIntegrator() = default;
    MISIntegrator(const ESamplingMode &mode);
    MISIntegrator(const Properties &props);

    static std::unordered_map<std::string, ESamplingMode> s_sampling_mode_map;
    ESamplingMode                                         m_sampling_mode = EMISPower;
};

struct UnidirectionalPathTracer : Integrator {
    UnidirectionalPathTracer(const Properties &props) : Integrator(props) {}
    virtual ~UnidirectionalPathTracer() {}

    // forward rendering
    virtual ArrayXd renderC(const Scene &scene, const RenderOptions &options) const;

    // differentiable rendering
    virtual ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const;

    // query pixel color
    Spectrum pixelColor(const Scene &scene, PixelQueryRecord &pRec) const;

    void pixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &d_res) const;

    Spectrum pixelColorFwd(SceneAD &sceneAD, PixelQueryRecord &pRec) const;

    // query radiance
    virtual Spectrum Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const = 0;

    virtual void LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const = 0;

    virtual Spectrum LiFwd([[maybe_unused]] SceneAD &sceneAD, [[maybe_unused]] const Ray &ray, [[maybe_unused]] RadianceQueryRecord &rRec) const { return Spectrum(0.0); }

    virtual std::string getName() const = 0;
};

/// Base type for boundary integrators; in this header it exposes only a
/// per-scene configure hook.
struct IntegratorBoundary {
    // Polymorphic base: a virtual destructor makes deleting a derived
    // integrator through an IntegratorBoundary* well-defined.
    virtual ~IntegratorBoundary() = default;

    /// Per-scene setup hook; the base implementation does nothing.
    virtual void configure([[maybe_unused]] const Scene &scene) {}
};

namespace {
struct DebugInfo {
    DebugInfo() {}
    DebugInfo(int nworkers, int npixels, int nsamples)
        : nworkers(nworkers), npixels(npixels), nsamples(nsamples) {
        image_per_thread.resize(nworkers);
        for (int i = 0; i < nworkers; i++) {
            image_per_thread[i].resize(npixels);
            for (int j = 0; j < static_cast<int>(image_per_thread[i].size()); j++) {
                image_per_thread[i][j] = Spectrum::Zero();
            }
        }
    }
    int nworkers;
    int npixels;
    int nsamples;
    std::vector<std::vector<Spectrum>> image_per_thread;
    std::vector<Spectrum> getData() {
        std::vector<Spectrum> ret(npixels, Spectrum::Zero());
        for (int i = 0; i < nworkers; i++) {
            for (int j = 0; j < npixels; j++) {
                ret[j] += image_per_thread[i][j];
            }
        }
        return ret;
    }

    ArrayXd getArray() {
        ArrayXd ret = ArrayXd::Zero(npixels * 3);
        for (int i = 0; i < nworkers; i++) {
            ret = ret + from_spectrum_list_to_tensor(
                            image_per_thread[i], npixels);
        }
        return ret;
    }

    void save([[maybe_unused]] ptr<float> d_image) {
        std::vector<Spectrum> img = getData();
        // from_spectrum_list_to_ptr(img, npixels, d_image);
    }
};
} // namespace

#endif // INTEGRATOR_H__