#pragma once
#include <core/logger.h>
#include <render/intersection.h>

// Forward declaration only: within this header SceneAD is used exclusively as a
// reference parameter type, so its full definition is not needed here.
struct SceneAD;
struct LightPath
{
    LightPath() {}
    LightPath(const Array2i &pixel_idx) : pixel_idx(pixel_idx) {}
    Array2i pixel_idx;
    std::vector<Intersection> vertices;
    std::vector<int> vs;

    void reserve(int n)
    {
        vertices.reserve(n);
        vs.reserve(n);
    }
    void resize(int n)
    {
        vertices.resize(n);
        vs.resize(n);
    }
    void clear()
    {
        vertices.clear();
        vs.clear();
    }
    void clear(const Array2i &pixel_idx)
    {
        this->pixel_idx = pixel_idx;
        vertices.clear();
        vs.clear();
    }

    void append(const Intersection &v)
    {
        vertices.push_back(v);
        vs.push_back(vertices.size() - 1);
    }
    void append_nee(const Intersection &v)
    {
        int cur_id = vs.back();
        // FIXME
        // Vertex &curV = vertices[cur_id];
        vertices.push_back(v);
        vertices[cur_id].l_nee_id = vertices.size() - 1;
    }
    void append_bsdf(const Intersection &v)
    {
        // FIXME
        // Vertex &curV = vertices.at(vs.back());
        int cur_id = vs.back();
        vertices.push_back(v);
        vertices[cur_id].l_bsdf_id = vertices.size() - 1;
    }
    void setZero()
    {
        for (auto &v : vertices)
            v.setZero();
    }
    void setZero(int n)
    {
        assert(n <= vertices.size());
        for (int i = 0; i < n; i++)
            vertices[i].setZero();
    }
    std::string toString() const
    {
        Color::Modifier red(Color::FG_RED);
        Color::Modifier def(Color::FG_DEFAULT);
        std::stringstream ss;
        ss << "vertices" << std::endl;
        for (auto &v : vertices)
        {
            ss << red << "\t type: " << v.type << def << std::endl;
            ss << "\t p: " << v.p[0] << " " << v.p[1] << " " << v.p[2] << std::endl;
            ss << "\t pdf: " << v.pdf << std::endl;
            ss << "\t J: " << v.J << std::endl;
            ss << "\t value: " << v.value[0] << " " << v.value[1] << " " << v.value[2] << std::endl;
            ss << "\t bsdf_bsdf: " << v.bsdf_bsdf[0] << " " << v.bsdf_bsdf[1] << " " << v.bsdf_bsdf[2] << std::endl;
            ss << "\t nee_bsdf: " << v.nee_bsdf[0] << " " << v.nee_bsdf[1] << " " << v.nee_bsdf[2] << std::endl;
            ss << "\t bsdf_id: " << v.l_bsdf_id << std::endl;
            ss << "\t nee_id: " << v.l_nee_id << std::endl;
            ss << "\t med_id: " << v.medium_id << std::endl;
        }
        ss << "vs" << std::endl;

        for (auto &v : vs)
        {
            const Intersection &vv = vertices[v];
            ss << "\t id: " << v << std::endl;
            ss << "\t\t nee_id: " << vv.l_nee_id << std::endl;
            ss << "\t\t bsdf_id: " << vv.l_bsdf_id << std::endl;
        }
        return ss.str();
    }
};

// Pairs a LightPath holding values (`val`) with a second LightPath (`der`)
// whose vertices are zeroed by the converting constructor and by zeroGrad().
struct LightPathAD
{
    LightPath val;
    LightPath der;

    LightPathAD() {}

    // Copies `val` into both members, then zeroes every vertex of `der`.
    LightPathAD(const LightPath &val) : val(val), der(val)
    {
        der.setZero();
    }

    // Note the asymmetry: `val` only reserves capacity for n entries, while
    // `der` is actually resized to n entries. zeroGrad() depends on `der`
    // holding at least val.vertices.size() vertices (LightPath::setZero(int)
    // asserts this).
    void reserve(int n)
    {
        val.reserve(n);
        der.resize(n);
    }

    // Empties both paths.
    void clear()
    {
        val.clear();
        der.clear();
    }

    // Zeroes as many leading vertices of `der` as `val` currently holds.
    void zeroGrad()
    {
        const auto used = val.vertices.size();
        der.setZero(used);
    }
};

namespace algorithm1_ptracer
{
    // Plain data bundle: a type tag, a sampler pointer and three Intersection
    // pointers. No member is initialised here and nothing in this struct manages
    // the lifetime of the pointees.
    // NOTE(review): the meaning of `type` and of the three slots is not visible
    // in this header — confirm against the implementation.
    struct Tuple
    {
        int type;
        RndSampler *sampler;
        std::array<Intersection *, 3> intersections;
    };

    // Two Tuple instances side by side, named `val` and `der`
    // (same val/der naming as the LightPathAD types in this header).
    struct TupleAD
    {
        Tuple val;
        Tuple der;
    };

    struct LightPath
    {
        void clear()
        {
            vertices.clear();
            path.clear();
        }
        void setZero()
        {
            for (auto &v : vertices)
            {
                v.setZero();
            }
        }
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            path.push_back(vertices.size() - 1);
        }
        void resize(int n)
        {
            vertices.resize(n);
            path.resize(n);
        }
        std::vector<Intersection> vertices;
        std::vector<int> path;
    };

#include <core/template.h>
    // AD wrapper around this namespace's LightPath.
    // NOTE(review): TypeAD is presumably supplied by the <core/template.h>
    // include directly above (included inside the namespace) — confirm.
    using LightPathAD = TypeAD<LightPath>;

    // The three antithetic* functions share one return layout (per the original note):
    // valid, anti_path, pdf1A, pdf1B, pdf2A, pdf2B
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic(const Scene &scene, const LightPath &path);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_surf(const Scene &scene, const LightPath &path, RndSampler *sampler);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_vol(const Scene &scene, const LightPath &path, RndSampler *sampler, bool is_equal_trans = false);
    // Returns a Spectrum for `path`; the path is taken by non-const reference,
    // so the call is allowed to modify it.
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    // Derivative-side counterpart of eval operating on a SceneAD and a LightPathAD.
    // NOTE(review): d_value is presumably the incoming derivative of the value
    // returned by eval — confirm against the implementation.
    void d_eval(SceneAD &sceneAD,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    // Takes a read-only scene plus a mutable d_scene instead of a SceneAD.
    // NOTE(review): presumably a reference implementation to compare d_eval against — confirm.
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);
}

// Declarations for the `algorithm1` variant. This namespace declares no path
// types of its own in this header, so LightPath / LightPathAD below name the
// global structs defined earlier in the file. The function bodies are not
// part of this header.
namespace algorithm1
{
    // Returns a Spectrum for `path`; the path is taken by non-const reference,
    // so the call is allowed to modify it.
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    // Derivative-side counterpart of eval: reads `scene`, may write `d_scene`
    // and `pathAD` (both non-const references).
    // NOTE(review): d_value is presumably the incoming derivative of eval's result — confirm.
    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    // Same signature as d_eval.
    // NOTE(review): presumably a reference implementation used for comparison — confirm.
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);
}

// Declarations for the `algorithm1_vol` variant. As in `algorithm1`, LightPath /
// LightPathAD refer to the global structs from the top of this header; the
// function bodies are not part of this header.
namespace algorithm1_vol
{
    // Returns a Spectrum for `path`; the path may be modified (non-const reference).
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    // The d_* helpers below take a read-only value object (`scene`, `path`, `v`)
    // next to a mutable counterpart (`d_scene`, `d_path`, `d_v`).
    // NOTE(review): the d_* arguments presumably accumulate derivatives — confirm.
    void d_evalPath(const Scene &scene, Scene &d_scene,
                    const LightPath &path, LightPath &d_path,
                    Spectrum d_value);
    void d_evalVertex(const Scene &scene, Scene &d_scene,
                      const LightPath &path, LightPath &d_path,
                      RndSampler *sampler);
    // Unlike the two functions above, `path` is non-const here.
    void d_getPath(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path);
    // Single-vertex variant: works on one Intersection and its d_v counterpart.
    void d_getPoint(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v);
    // Entry point taking the combined LightPathAD (val + der).
    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    // Same signature as d_eval.
    // NOTE(review): presumably a reference implementation used for comparison — confirm.
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);
}

namespace algorithm1_bdpt
{
    // Path container used inside namespace algorithm1_bdpt.
    //   vertices       : stored intersections, addressable through operator[].
    //   weight         : one Float per vertex; append() adds 0.f for each new vertex.
    //   antithetic_vtx : three extra intersections zeroed together with the path.
    //   isCameraPath   : flag, false until a caller sets it.
    //   pixelIdx       : left default-constructed by this struct.
    struct LightPath
    {
        // Fix: isCameraPath used to be left indeterminate by the default
        // constructor, so reading or copying a fresh path before the flag was
        // assigned was undefined behaviour. It now starts out as false.
        LightPath() : isCameraPath(false) {}

        // Empties `vertices` and `weight`; the flag, pixelIdx and
        // antithetic_vtx are left untouched.
        void clear()
        {
            vertices.clear();
            weight.clear();
        }
        // Calls setZero() on every vertex and on the three antithetic vertices;
        // `weight` is not modified.
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            for (auto &v : antithetic_vtx)
                v.setZero();
        }
        // Adds v together with a matching weight entry of 0.
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            weight.push_back(0.f);
        }
        // Reserves capacity for n entries in both vectors (sizes are unchanged).
        void reserve(int n)
        {
            vertices.reserve(n);
            weight.reserve(n);
        }
        // Number of stored vertices, narrowed to int as in the original interface.
        int size() const { return static_cast<int>(vertices.size()); }
        // Unchecked element access into `vertices`.
        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        bool isCameraPath;
        Array2i pixelIdx;
        Intersection antithetic_vtx[3];
        std::vector<Intersection> vertices;
        std::vector<Float> weight;
    };

#include <core/template.h>
    // AD wrapper around this namespace's LightPath.
    // NOTE(review): TypeAD is presumably supplied by the <core/template.h>
    // include directly above (included inside the namespace) — confirm.
    using LightPathAD = TypeAD<LightPath>;
    // Takes a camera path with an int `s` and a light path with an int `t`, and
    // returns a Spectrum.
    // NOTE(review): s / t presumably select how many vertices of each sub-path
    // are used, and w is presumably a weight — confirm against the implementation.
    Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD &cameraPathAD, int s, LightPathAD &lightPathAD, int t,
                        Float w, bool antithetic_success, Spectrum d_value);
    // Per-vertex helpers addressed by `index` into pathAD.
    void d_evalVertex(SceneAD &sceneAD, LightPathAD &pathAD, int index);
    void d_getPoint(SceneAD &sceneAD, LightPathAD &pathAD, int index, bool antithetic);

}

namespace algorithm1_pathwas
{
    
    // #define FWDDIFF
    // Namespace-scope tunables. They are declared `inline` (C++17) because this is
    // a header: a plain non-const definition here is emitted in every translation
    // unit that includes the file and fails to link with "multiple definition"
    // errors. With `inline`, all translation units share one variable.
    // NUM_AUX_SAMPLES is the number of WASCache entries LightPathWAS allocates per vertex.
    inline int NUM_AUX_SAMPLES = 8;
    // NOTE(review): base_sigma and power are not used inside this header; their
    // meaning is defined by the implementation file.
    inline Float base_sigma = 0.006;
    inline Float power = 1.4;

    // One cache record: two intersections (`aux_half`, `aux`), a scalar `w`,
    // a vector `dw` and two flags. A freshly constructed record is in the
    // state produced by setZero().
    struct WASCache{
        Intersection aux_half;
        Intersection aux;
        Float w;
        Vector dw;
        bool force_zero;
        bool VALID;

        WASCache() { setZero(); }

        // Zeroes both intersections, w and dw, clears force_zero and
        // marks the record as VALID.
        void setZero() {
            VALID = true;
            force_zero = false;
            w = 0.0;
            dw.setZero();
            aux.setZero();
            aux_half.setZero();
        }
    };
    struct LightPathWAS
    {
        LightPathWAS() {}
        LightPathWAS(const Array2i &pixel_idx) : pixel_idx(pixel_idx) {}
        Array2i pixel_idx;
        std::vector<Intersection> vertices;
        std::vector<int> vs;

        
        std::vector<std::vector<WASCache>> vertex_cache;
        std::vector<Vector> warped_X;
        std::vector<Float> div_warped_X;
        std::vector<Float> prefix_div_warped_X;

        void reserve(int n)
        {
            vertices.reserve(n);
            vs.reserve(n);

            vertex_cache.reserve(n);
            warped_X.reserve(n);
            div_warped_X.reserve(n);
            prefix_div_warped_X.reserve(n);
        }
        void resize(int n)
        {
            vertices.resize(n);
            vs.resize(n);
            
            vertex_cache.resize(n);
            for (int i = 0; i < n; i++)
            {
                vertex_cache[i].resize(NUM_AUX_SAMPLES);
            }
            
            warped_X.resize(n);
            div_warped_X.resize(n);
            prefix_div_warped_X.resize(n);
        }
        void resize(const LightPathWAS& other)
        {
            vertices.resize(other.vertices.size());
            vs.resize(other.vs.size());
            
            vertex_cache.resize(other.vertex_cache.size());
            warped_X.resize(other.warped_X.size());
            div_warped_X.resize(other.div_warped_X.size());
            prefix_div_warped_X.resize(other.prefix_div_warped_X.size());

            for (int i = 0; i < other.vertex_cache.size(); i++)
            {
                vertex_cache[i].resize(other.vertex_cache[i].size());
            }
            

        }
        void clear()
        {
            vertices.clear();
            vs.clear();
            
            vertex_cache.clear();
            warped_X.clear();
            div_warped_X.clear();
            prefix_div_warped_X.clear();
        }
        void clear(const Array2i &pixel_idx)
        {
            this->pixel_idx = pixel_idx;
            vertices.clear();
            vs.clear();
            
            vertex_cache.clear();
            warped_X.clear();
            div_warped_X.clear();
            prefix_div_warped_X.clear();
        }

        void append(const Intersection &v)
        {
            vertices.push_back(v);
            vs.push_back(vertices.size() - 1);
            
            vertex_cache.push_back(std::vector<WASCache>(NUM_AUX_SAMPLES));
            Vector tmp;
            warped_X.push_back(tmp);
            div_warped_X.push_back(0.0);
            prefix_div_warped_X.push_back(0.0);
        }
        void append_nee(const Intersection &v)
        {
            int cur_id = vs.back();
            // FIXME
            // Vertex &curV = vertices[cur_id];
            vertices.push_back(v);
            vertices[cur_id].l_nee_id = vertices.size() - 1;
            
            vertex_cache.push_back(std::vector<WASCache>(NUM_AUX_SAMPLES));
            Vector tmp;
            warped_X.push_back(tmp);
            div_warped_X.push_back(0.0);
            prefix_div_warped_X.push_back(0.0);
        }
        void append_bsdf(const Intersection &v)
        {
            // FIXME
            // Vertex &curV = vertices.at(vs.back());
            int cur_id = vs.back();
            vertices.push_back(v);
            vertices[cur_id].l_bsdf_id = vertices.size() - 1;

            vertex_cache.push_back(std::vector<WASCache>(NUM_AUX_SAMPLES));
            Vector tmp;
            warped_X.push_back(tmp);
            div_warped_X.push_back(0.0);
            prefix_div_warped_X.push_back(0.0);
        }
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            for (auto &X : warped_X) {
                X.setZero();
            }
            for (auto &dX : div_warped_X) {
                dX = 0.0;
            }
            for (auto &l : vertex_cache){
                for (auto &cache : l) {
                    cache.setZero();
                }
            }
            for (auto &dX : prefix_div_warped_X) {
                dX = 0.0;
            }
        }
        void setZero(int n)
        {
            assert(n <= vertices.size());
            for (int i = 0; i < n; i++){
                vertices[i].setZero();
                warped_X[i].setZero();
                div_warped_X[i] = 0.0;
                for (auto &cache : vertex_cache[i]) {
                    cache.setZero();
                }
                prefix_div_warped_X[i] = 0.0;
            }
        }
        std::string toString() const
        {
            Color::Modifier red(Color::FG_RED);
            Color::Modifier def(Color::FG_DEFAULT);
            std::stringstream ss;
            ss << "vertices" << std::endl;
            for (auto &v : vertices)
            {
                ss << red << "\t type: " << v.type << def << std::endl;
                ss << "\t p: " << v.p[0] << " " << v.p[1] << " " << v.p[2] << std::endl;
                ss << "\t pdf: " << v.pdf << std::endl;
                ss << "\t J: " << v.J << std::endl;
                ss << "\t value: " << v.value[0] << " " << v.value[1] << " " << v.value[2] << std::endl;
                ss << "\t bsdf_bsdf: " << v.bsdf_bsdf[0] << " " << v.bsdf_bsdf[1] << " " << v.bsdf_bsdf[2] << std::endl;
                ss << "\t nee_bsdf: " << v.nee_bsdf[0] << " " << v.nee_bsdf[1] << " " << v.nee_bsdf[2] << std::endl;
                ss << "\t bsdf_id: " << v.l_bsdf_id << std::endl;
                ss << "\t nee_id: " << v.l_nee_id << std::endl;
                ss << "\t med_id: " << v.medium_id << std::endl;
            }
            ss << "vs" << std::endl;

            for (auto &v : vs)
            {
                const Intersection &vv = vertices[v];
                ss << "\t id: " << v << std::endl;
                ss << "\t\t nee_id: " << vv.l_nee_id << std::endl;
                ss << "\t\t bsdf_id: " << vv.l_bsdf_id << std::endl;
            }
            return ss.str();
        }
    };

    
    // Pairs a LightPathWAS holding values (`val`) with a second LightPathWAS
    // (`der`) that is zeroed by the converting constructor and by zeroGrad().
    struct LightPathWASAD
    {
        LightPathWAS val;
        LightPathWAS der;

        LightPathWASAD() {}

        // Copies `val` into both members, then zeroes all of `der`.
        LightPathWASAD(const LightPathWAS &val) : val(val), der(val)
        {
            der.setZero();
        }

        // Resizes both paths to n entries.
        void resize(int n)
        {
            val.resize(n);
            der.resize(n);
        }

        // Empties both paths.
        void clear()
        {
            val.clear();
            der.clear();
        }

        // Zeroes as many leading entries of `der` as `val` currently has vertices.
        void zeroGrad()
        {
            const auto used = val.vertices.size();
            der.setZero(used);
        }
    };

    // Returns a Spectrum for `path`; the path is taken by non-const reference,
    // so the call is allowed to modify it (including its auxiliary arrays).
    Spectrum eval(const Scene &scene, LightPathWAS &path, RndSampler *sampler);
    // Derivative-side counterpart of eval: reads `scene`, may write `d_scene`
    // and `pathAD` (both non-const references).
    // NOTE(review): d_value is presumably the incoming derivative of eval's result — confirm.
    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathWASAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    // Disabled declaration, kept for reference:
    // void baseline(const Scene &scene, Scene &d_scene,
    //               LightPathAD &pathAD,
    //               Spectrum d_value, RndSampler *sampler);
}


namespace algorithm1_bdptwas
{   
    // Namespace-scope tunables. They are declared `inline` (C++17) because this is
    // a header: a plain non-const definition here is emitted in every translation
    // unit that includes the file and fails to link with "multiple definition"
    // errors. With `inline`, all translation units share one variable.
    // NUM_AUX_SAMPLES is the number of WASCache entries a WASContainer allocates.
    inline int NUM_AUX_SAMPLES = 8;
    // NOTE(review): base_sigma, power and use_prefix_sum are not used inside this
    // header; their meaning is defined by the implementation file.
    inline Float base_sigma = 0.03;
    inline Float power = 1.4;
    inline bool use_prefix_sum = true;

    // One cache record: two intersections (`aux_half`, `aux`), a scalar `w`,
    // a vector `dw` and two flags. A freshly constructed record is in the
    // state produced by setZero().
    struct WASCache{
        Intersection aux_half;
        Intersection aux;
        Float w;
        Vector dw;
        bool force_zero;
        bool VALID;

        WASCache() { setZero(); }

        // Zeroes both intersections, w and dw, clears force_zero and
        // marks the record as VALID.
        void setZero() {
            VALID = true;
            force_zero = false;
            w = 0.0;
            dw.setZero();
            aux.setZero();
            aux_half.setZero();
        }
    };

    struct WASContainer{
        std::vector<WASCache> vertex_cache;
        Vector warped_X;
        Float div_warped_X;
        bool hasGradient;
        Vector sum_warped_X;
        Float sum_div_warped_X;
        WASContainer() {
            vertex_cache.resize(NUM_AUX_SAMPLES);
            div_warped_X = 0.0;
            hasGradient = false;
            setZero();
        }
        void setZero() {
            for (auto &cache : vertex_cache) {
                cache.setZero();
            }
            warped_X.setZero();
            div_warped_X = 0.0;
            hasGradient = false;
        }
    };

    struct LightPath
    {
        LightPath() {}

        void clear()
        {
            vertices.clear();
            vertex_container.clear();
            weight.clear();
        }
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            // set zero for all caches
            for (auto &c : vertex_container) {
                c.setZero();
            }
            for (auto &c : connect_vertex_cache) {
                c.setZero();
            }
            connect_warped_X.setZero();
            connect_div_warped_X = 0.0;
        }
        void resetHasGradient()
        {
            for (auto &c : vertex_container) {
                c.hasGradient = false;
            }
        }
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            vertex_container.push_back(WASContainer());
            weight.push_back(0.f);
        }
        void reserve(int n)
        {
            vertices.reserve(n);
            vertex_container.reserve(n);
            weight.reserve(n);
        }
        void resize(const LightPath& path)
        {
            vertices.resize(path.vertices.size());
            vertex_container.resize(path.vertex_container.size());
            weight.resize(path.weight.size());
        }
        int size() const { return vertices.size(); }
        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        bool isCameraPath;
        Array2i pixelIdx;
        std::vector<Intersection> vertices;
        std::vector<WASContainer> vertex_container;
        // std::vector<std::vector<WASCache>> vertex_cache;
        std::vector<Float> weight;
        // std::vector<Vector> warped_X;
        // std::vector<Float> div_warped_X;
        // std::vector<Vector> sum_warped_X;
        // std::vector<Float> sum_div_warped_X;
        std::vector<WASCache> connect_vertex_cache;
        Float connect_div_warped_X;
        Vector connect_warped_X;
    };

#include <core/template.h>
    // AD wrapper around this namespace's LightPath.
    // NOTE(review): TypeAD is presumably supplied by the <core/template.h>
    // include directly above (included inside the namespace) — confirm.
    using LightPathAD = TypeAD<LightPath>;
    // Disabled declaration, kept for reference:
    // Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD &cameraPathAD, int s, LightPathAD &lightPathAD, int t,
    //                     Float w, Spectrum d_value);
    // Per-vertex helpers addressed by `index` into pathAD.
    void d_evalVertex(SceneAD &sceneAD, LightPathAD &pathAD, int index);

    void d_getPoint(SceneAD &sceneAD, LightPathAD &pathAD, int index);
    // Declarations of the functions defined in algorithm1_bdptwas.cpp.
    // Throughout, a camera path is paired with an int `s` and a light path with
    // an int `t`.
    // NOTE(review): s / t presumably select how many vertices of each sub-path
    // take part — confirm against the implementation.
    void evalWarp(const Scene &scene, 
                LightPath &camera_path, int s,
                LightPath &light_path, int t);

    // Single-path variant of evalWarp.
    void evalWarpPath(const Scene &scene, 
                LightPath &path);

    // Read-only on both paths; returns a Spectrum computed from them and L.
    Spectrum evalBoundary(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t, 
            const Spectrum &L);

    // NOTE(review): presumably fills the sum_* members of the per-vertex
    // containers — confirm against the implementation.
    void evalPrefix(LightPath& path);

    void d_evalPrefix(LightPathAD &pathAD);
    
    // The d_* functions below are the LightPathAD counterparts of the functions above.
    void d_evalBoundary(const Scene &scene,
                        LightPathAD &camera_path, int s, LightPathAD &light_path, int t,// only use Warp and divWarp
                        const Spectrum &L,
                        Spectrum d_value);
    
    // Same signature as d_evalBoundary.
    // NOTE(review): presumably the variant selected by use_prefix_sum — confirm.
    void d_evalBoundary_prefix(const Scene &scene,
                        LightPathAD &camera_path, int s, LightPathAD &light_path, int t,// only use Warp and divWarp
                        const Spectrum &L,
                        Spectrum d_value);
    
    void d_evalWarp(const Scene &scene,
                    LightPathAD &camera_path, int s,
                    LightPathAD &light_path, int t);
        
    void d_evalWarpPath(const Scene &scene,
                    LightPathAD &pathAD);
        
    // These two take a SceneAD (rather than a const Scene) as their first argument.
    void d_getWAS(SceneAD &scene, 
                        LightPathAD &camera_path, int s, 
                        LightPathAD &light_path, int t);
    void d_getWASPath(SceneAD &scene, 
                        LightPathAD &path);
    
    // Reads cache_list; d_cache_list, d_warp and d_div_warp are non-const
    // references and may therefore be written.
    void d_warpSurface(const std::vector<WASCache> &cache_list, std::vector<WASCache> &d_cache_list,
                    Vector &d_warp, Float &d_div_warp);
    
    // Takes three read-only intersections; d_x is a non-const reference and may be written.
    void d_rayIntersectEdgeExt(SceneAD& sceneAD, 
                                const Intersection& origin, const Intersection &edge_ext, const Intersection &aux, Vector& d_x);
}