#pragma once
#include <core/logger.h>
#include <render/intersection.h>

struct SceneAD;
struct LightPath
{
    LightPath() {}
    LightPath(const Array2i &pixel_idx) : pixel_idx(pixel_idx) {}
    Array2i pixel_idx;
    std::vector<Intersection> vertices;
    std::vector<int> vs;

    void reserve(int n)
    {
        vertices.reserve(n);
        vs.reserve(n);
    }
    void resize(int n)
    {
        vertices.resize(n);
        vs.resize(n);
    }
    void clear()
    {
        vertices.clear();
        vs.clear();
    }
    void clear(const Array2i &pixel_idx)
    {
        this->pixel_idx = pixel_idx;
        vertices.clear();
        vs.clear();
    }

    void append(const Intersection &v)
    {
        vertices.push_back(v);
        vs.push_back(vertices.size() - 1);
    }
    void append_nee(const Intersection &v)
    {
        int cur_id = vs.back();
        // FIXME
        // Vertex &curV = vertices[cur_id];
        vertices.push_back(v);
        vertices[cur_id].l_nee_id = vertices.size() - 1;
    }
    void append_bsdf(const Intersection &v)
    {
        // FIXME
        // Vertex &curV = vertices.at(vs.back());
        int cur_id = vs.back();
        vertices.push_back(v);
        vertices[cur_id].l_bsdf_id = vertices.size() - 1;
    }
    void setZero()
    {
        for (auto &v : vertices)
            v.setZero();
    }
    void setZero(int n)
    {
        assert(n <= vertices.size());
        for (int i = 0; i < n; i++)
            vertices[i].setZero();
    }
    std::string toString() const
    {
        Color::Modifier red(Color::FG_RED);
        Color::Modifier def(Color::FG_DEFAULT);
        std::stringstream ss;
        ss << "vertices" << std::endl;
        for (auto &v : vertices)
        {
            ss << red << "\t type: " << v.type << def << std::endl;
            ss << "\t p: " << v.p[0] << " " << v.p[1] << " " << v.p[2] << std::endl;
            ss << "\t pdf: " << v.pdf << std::endl;
            ss << "\t J: " << v.J << std::endl;
            ss << "\t value: " << v.value[0] << " " << v.value[1] << " " << v.value[2] << std::endl;
            ss << "\t bsdf_bsdf: " << v.bsdf_bsdf[0] << " " << v.bsdf_bsdf[1] << " " << v.bsdf_bsdf[2] << std::endl;
            ss << "\t nee_bsdf: " << v.nee_bsdf[0] << " " << v.nee_bsdf[1] << " " << v.nee_bsdf[2] << std::endl;
            ss << "\t bsdf_id: " << v.l_bsdf_id << std::endl;
            ss << "\t nee_id: " << v.l_nee_id << std::endl;
            ss << "\t med_id: " << v.medium_id << std::endl;
        }
        ss << "vs" << std::endl;

        for (auto &v : vs)
        {
            const Intersection &vv = vertices[v];
            ss << "\t id: " << v << std::endl;
            ss << "\t\t nee_id: " << vv.l_nee_id << std::endl;
            ss << "\t\t bsdf_id: " << vv.l_bsdf_id << std::endl;
        }
        return ss.str();
    }
};

struct LightPathAD
{
    LightPath val; // primal path
    LightPath der; // zeroed mirror of `val`; presumably holds derivatives

    LightPathAD() {}

    /// Copies `val` into both members, then zeroes every vertex of `der`.
    LightPathAD(const LightPath &val) : val(val), der(val)
    {
        der.setZero();
    }

    /// Reserves capacity in `val` but *resizes* `der` to `n` vertices
    /// (presumably so zeroGrad() always has storage to zero).
    void reserve(int n)
    {
        der.resize(n);
        val.reserve(n);
    }

    /// Empties both paths.
    /// NOTE(review): this also empties `der`, so a later zeroGrad() with a
    /// non-empty `val` trips LightPath::setZero's size assert unless
    /// reserve() is called again — confirm callers rely on that ordering.
    void clear()
    {
        der.clear();
        val.clear();
    }

    /// Zeroes the first val.vertices.size() vertices of `der`.
    void zeroGrad()
    {
        const int count = val.vertices.size();
        der.setZero(count);
    }
};

namespace algorithm1_ptracer
{
    /// A type tag, a sampler and three intersection pointers.
    /// Members are value-initialized so a default-constructed Tuple (and
    /// therefore both halves of TupleAD) never holds indeterminate pointers.
    struct Tuple
    {
        int type = 0;
        RndSampler *sampler = nullptr;                  // presumably non-owning
        std::array<Intersection *, 3> intersections{}; // all nullptr by default
    };

    /// Pair of tuples: `val` and its derivative counterpart `der`.
    struct TupleAD
    {
        Tuple val;
        Tuple der;
    };

    /// Particle-tracer path: vertex storage plus the indices of the vertices
    /// added via append(), in insertion order.
    struct LightPath
    {
        /// Removes all vertices and path indices.
        void clear()
        {
            vertices.clear();
            path.clear();
        }
        /// Calls setZero() on every stored vertex.
        void setZero()
        {
            for (auto &v : vertices)
            {
                v.setZero();
            }
        }
        /// Appends a vertex and records its index in `path`.
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            path.push_back(static_cast<int>(vertices.size()) - 1);
        }
        /// Resizes both `vertices` and `path` to exactly `n` entries.
        void resize(int n)
        {
            vertices.resize(n);
            path.resize(n);
        }
        std::vector<Intersection> vertices;
        std::vector<int> path;
    };

#include <core/template.h>
    using LightPathAD = TypeAD<LightPath>;

    // valid, anti_path, pdf1A, pdf1B, pdf2A, pdf2B
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic(const Scene &scene, const LightPath &path);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_surf(const Scene &scene, const LightPath &path, RndSampler *sampler);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_vol(const Scene &scene, const LightPath &path, RndSampler *sampler, bool is_equal_trans = false);
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    void d_eval(SceneAD &sceneAD,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);
}

namespace algorithm1
{
    // Declarations only; the definitions are not part of this header.
    // `LightPath` / `LightPathAD` here are the global types declared above.

    void d_velocity(SceneAD &sceneAD, const Intersection &its, Float d_u);

    // Entry points that take no derivative scene.
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    Spectrum evalFwd(const Scene &scene, LightPath &path, RndSampler *sampler);

    // Entry points that accumulate into `d_scene` given an incoming `d_value`.
    void d_eval(const Scene &scene,
                Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value,
                RndSampler *sampler);
    void baseline(const Scene &scene,
                  Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value,
                  RndSampler *sampler);

    // Entry point returning a pair of spectra (naming suggests forward mode).
    std::pair<Spectrum, Spectrum> d_evalFwd(const Scene &scene,
                                            Scene &d_scene,
                                            LightPathAD &pathAD,
                                            RndSampler *sampler);
}

namespace algorithm1_vol
{
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    void d_evalPath(const Scene &scene, Scene &d_scene,
                    const LightPath &path, LightPath &d_path,
                    Spectrum d_value);
    void d_evalVertex(const Scene &scene, Scene &d_scene,
                      const LightPath &path, LightPath &d_path,
                      RndSampler *sampler);
    void d_getPath(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path);
    void d_getPoint(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v);

    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);

    void getPointFwd(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v);
    void getPathFwd(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path);
    void evalVertexFwd(const Scene &scene, Scene &d_scene,
                       const LightPath &path, LightPath &d_path,
                       RndSampler *sampler);

    std::pair<Spectrum, Spectrum> evalPathFwd(const Scene &scene, Scene &d_scene,
                                              const LightPath &path, LightPath &d_path);
    std::pair<Spectrum, Spectrum> evalFwd(const Scene &scene, Scene &d_scene,
                                          LightPathAD &pathAD, RndSampler *sampler);
    std::pair<Spectrum, Spectrum> baselineFwd(const Scene &scene, Scene &d_scene,
                                              LightPathAD &pathAD, RndSampler *sampler);
}

namespace algorithm1_bdpt
{
    /// Bidirectional path: vertices with a parallel per-vertex weight array,
    /// plus three extra antithetic vertices.
    struct LightPath
    {
        LightPath() {}

        /// Removes all vertices and weights (antithetic_vtx is untouched).
        void clear()
        {
            vertices.clear();
            weight.clear();
        }
        /// Calls setZero() on every vertex and on all antithetic vertices.
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            for (auto &v : antithetic_vtx)
                v.setZero();
        }
        /// Appends a vertex with an initial weight of 0.
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            weight.push_back(0.f);
        }
        /// Reserves capacity for `n` vertices and weights.
        void reserve(int n)
        {
            vertices.reserve(n);
            weight.reserve(n);
        }
        int size() const { return static_cast<int>(vertices.size()); }
        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        // Defaulted so a freshly constructed path never holds an
        // indeterminate bool (the user-provided ctor does not set it).
        bool isCameraPath = false;
        Array2i pixelIdx;
        Intersection antithetic_vtx[3];
        std::vector<Intersection> vertices;
        std::vector<Float> weight; // one entry per vertex, filled by append()
    };

#include <core/template.h>
    using LightPathAD = TypeAD<LightPath>;
    Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD &cameraPathAD, int s, LightPathAD &lightPathAD, int t,
                        Float w, bool antithetic_success, Spectrum d_value);
    void d_evalVertex(SceneAD &sceneAD, LightPathAD &pathAD, int index);
    void d_getPoint(SceneAD &sceneAD, LightPathAD &pathAD, int index, bool antithetic);

}

namespace algorithm1_MALA_direct
{

    struct EdgeBound { // edge[idx].v0 corresponds to min, edge[idx].v1 corresponds to max
        Float min;
        Float max;
        Vector emitter_point;
        Vector emitter_dir;
        int mode; // 0: surface, 1: direction
        int edge_idx;
        int shape_idx;
    };

    /// Path state for the direct-lighting MALA variant: vertices plus a
    /// `dim`-sized sample-state vector (presumably primary-sample-space).
    struct LightPathPSS
    {
        /// Sizes pss_state and discrete_dim to `dim` (values are not
        /// initialized here) and NUL-terminates `type` at index `dim`.
        /// Precondition: 0 <= dim < sizeof(type).
        LightPathPSS(int dim) {
            assert(dim >= 0 && dim < static_cast<int>(sizeof(type)));
            pss_state.resize(dim);
            discrete_dim.resize(dim);
            // Clamp so an oversized `dim` cannot write past the end of
            // `type` in builds where the assert is compiled out.
            const int last = static_cast<int>(sizeof(type)) - 1;
            type[dim < last ? dim : last] = '\0';
        }

        /// Removes all vertices and zeroes pss_state (size kept).
        void clear()
        {
            vertices.clear();
            pss_state.setZero();
        }
        /// Zeroes every vertex, pss_state, baseValue and discrete_dim.
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            pss_state.setZero();
            baseValue.setZero();
            discrete_dim.setZero();
        }
        void append(const Intersection &v)
        {
            vertices.push_back(v);
        }
        void reserve(int n)
        {
            vertices.reserve(n);
        }
        /// Matches `other`'s vertex count and pss_state size, with pss_state
        /// zeroed. discrete_dim is not resized here.
        void resize(const LightPathPSS &other) {
            vertices.resize(other.vertices.size());
            pss_state = other.pss_state;
            pss_state.setZero();
        }
        int size() const { return static_cast<int>(vertices.size()); }
        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        std::vector<Intersection> vertices;
        ArrayXd pss_state;
        ArrayXd discrete_dim;
        EdgePrimarySampleRecord ePSRec;
        EmitterPrimarySampleRecord dPSRec;
        Spectrum baseValue;

        char type[100]; // NUL-terminated at index `dim` by the constructor

        EdgeBound bound;
    };

#include <core/template.h>
    using LightPathPSSAD = TypeAD<LightPathPSS>;

    Spectrum eval(const Scene &scene,
                    const ArrayXd &d_rnd, int max_bounces,
                    const DiscreteDistribution &edge_dist,
                    const std::vector<Vector2i> &edge_indices,
                    LightPathPSS *path = nullptr, bool echo = false);

    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathPSSAD &pathAD);

}

namespace algorithm1_MALA_indirect
{
    // #define DIFF_DIM0

    struct EdgeBound { // edge[idx].v0 corresponds to min, edge[idx].v1 corresponds to max
        Float min;
        Float max;
        Vector dir;
        int edge_idx;
        int shape_idx;
    };

    /// Path state for the indirect MALA variant, split into an "S" part and a
    /// "D" part (vertices, sample state, discrete dims), plus a 3-vector
    /// pss_stateE placed first by getPSSState().
    struct LightPathPSS
    {
        /// Sizes pss_stateS, pss_stateD and discrete_dimS to `dim` (values
        /// are not initialized here) and NUL-terminates `typeS` at `dim`.
        /// Precondition: 0 <= dim < sizeof(typeS).
        LightPathPSS(int dim) {
            assert(dim >= 0 && dim < static_cast<int>(sizeof(typeS)));
            pss_stateS.resize(dim);
            pss_stateD.resize(dim);
            discrete_dimS.resize(dim);
            // NOTE(review): discrete_dimD is still empty here, so this is a
            // no-op; the sibling members use resize(dim). Confirm intent.
            discrete_dimD.setZero();
            // Clamp so an oversized `dim` cannot write past the end of
            // `typeS` in builds where the assert is compiled out.
            const int last = static_cast<int>(sizeof(typeS)) - 1;
            typeS[dim < last ? dim : last] = '\0';
        }

        /// Removes all vertices and zeroes every sample-state vector.
        void clear()
        {
            verticesS.clear();
            verticesD.clear();
            pss_stateS.setZero();
            pss_stateD.setZero();
            pss_stateE.setZero();
        }
        /// Zeroes all vertices, sample states, discrete dims and baseValue.
        void setZero()
        {
            for (auto &v : verticesS)
                v.setZero();
            for (auto &v : verticesD)
                v.setZero();
            pss_stateS.setZero();
            pss_stateD.setZero();
            pss_stateE.setZero();
            discrete_dimS.setZero();
            discrete_dimD.setZero();
            baseValue = 0.0;
        }
        /// Matches every container size of `other`, then zeroes everything.
        void resize(const LightPathPSS &other) {
            verticesS.resize(other.verticesS.size());
            verticesD.resize(other.verticesD.size());
            pss_stateS = other.pss_stateS;
            pss_stateD = other.pss_stateD;
            discrete_dimD = other.discrete_dimD;
            discrete_dimS = other.discrete_dimS;
            setZero();
        }

        // ArrayXd getDiscreteDim(int max_bounces, int cam_bounce) const {
        //     ArrayXd discrete_dim(3 * max_bounces);
        //     // Vector tmp = Vector::Zero();
        //     discrete_dim.head(3).setZero();
        //     if (max_bounces - cam_bounce - 1 > 0){
        //         // PSDR_INFO("discrete_dimS.size(): {}, max_bounces: {}, cam_bounce: {}", discrete_dimS.size(), max_bounces, cam_bounce);
        //         assert(discrete_dimS.size() == 3 * (max_bounces - cam_bounce - 1));
        //         discrete_dim.segment(3, 3 * (max_bounces - cam_bounce - 1)) = discrete_dimS;
        //     }
        //     if (cam_bounce > 0){
        //         assert(discrete_dimD.size() == 3 * cam_bounce);
        //         discrete_dim.segment(3 * (max_bounces - cam_bounce), 3 * cam_bounce) = discrete_dimD;
        //     }
        //     return discrete_dim;
        // }

        /// Concatenates the state into a 3*max_bounces array laid out as
        ///   [ pss_stateE (3) | pss_stateS (3*(max_bounces-cam_bounce-1)) |
        ///     pss_stateD (3*cam_bounce) ].
        /// The layout assumes cam_bounce < max_bounces; otherwise the D part
        /// starts at index 0 and overwrites pss_stateE.
        ArrayXd getPSSState(int max_bounces, int cam_bounce) const {
            ArrayXd pss_state(3 * max_bounces);
            pss_state.head(3) = pss_stateE;
            if (max_bounces - cam_bounce - 1 > 0){
                assert(pss_stateS.size() == 3 * (max_bounces - cam_bounce - 1));
                pss_state.segment(3, 3 * (max_bounces - cam_bounce - 1)) = pss_stateS;
            }
            if (cam_bounce > 0){
                assert(pss_stateD.size() == 3 * cam_bounce);
                pss_state.segment(3 * (max_bounces - cam_bounce), 3 * cam_bounce) = pss_stateD;
            }
            return pss_state;
        }

#ifndef DIFF_DIM0
        /// Like getPSSState() but with a 6-entry edge prefix, giving a
        /// 3*(max_bounces+1) array laid out as
        ///   [ edge_point (3) | edge_dir (3) |
        ///     pss_stateS (3*(max_bounces-cam_bounce-1)) | pss_stateD (3*cam_bounce) ].
        ArrayXd getEdgePSSState(int max_bounces, int cam_bounce) const {
            ArrayXd pss_state(3 * (max_bounces + 1));
            pss_state.head(3) = edge_point;
            pss_state.segment(3, 3) = edge_dir;
            if (max_bounces - cam_bounce - 1 > 0){
                assert(pss_stateS.size() == 3 * (max_bounces - cam_bounce - 1));
                pss_state.segment(6, 3 * (max_bounces - cam_bounce - 1)) = pss_stateS;
            }
            if (cam_bounce > 0){
                assert(pss_stateD.size() == 3 * cam_bounce);
                pss_state.segment(3 * (max_bounces - cam_bounce + 1), 3 * cam_bounce) = pss_stateD;
            }
            return pss_state;
        }
#endif

        std::vector<Intersection> verticesD;
        std::vector<Intersection> verticesS;
        ArrayXd pss_stateS;
        ArrayXd pss_stateD;
        Vector pss_stateE;
        ArrayXd discrete_dimS;
        ArrayXd discrete_dimD;
        EdgePrimarySampleRecord ePSRec;
#ifndef DIFF_DIM0
        Vector edge_point;
        Vector edge_dir;
#endif
        // EmitterPrimarySampleRecord dPSRec;
        Float baseValue;

        char typeS[100]; // NUL-terminated at index `dim` by the constructor
        char typeD[100]; // NOTE(review): not initialized by the constructor

        EdgeBound bound;
    };

#include <core/template.h>
    using LightPathPSSAD = TypeAD<LightPathPSS>;

    Spectrum eval(const Scene &scene,
                    const ArrayXd &d_rnd, int max_bounces, int cam_bounce,
                    const DiscreteDistribution &edge_dist,
                    const std::vector<Vector2i> &edge_indices,
                    LightPathPSS *path = nullptr, bool echo = false);

    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathPSSAD &pathAD);

}