#pragma once
#include <core/logger.h>
#include <render/intersection.h>

struct SceneAD;
struct LightPath
{
    LightPath() {}
    LightPath(const Array2i &pixel_idx) : pixel_idx(pixel_idx) {}
    Array2i pixel_idx;
    std::vector<Intersection> vertices;
    std::vector<int> vs;

    void reserve(int n)
    {
        vertices.reserve(n);
        vs.reserve(n);
    }
    void resize(int n)
    {
        vertices.resize(n);
        vs.resize(n);
    }
    void clear()
    {
        vertices.clear();
        vs.clear();
    }
    void clear(const Array2i &pixel_idx)
    {
        this->pixel_idx = pixel_idx;
        vertices.clear();
        vs.clear();
    }

    void append(const Intersection &v)
    {
        vertices.push_back(v);
        vs.push_back(vertices.size() - 1);
    }
    void append_nee(const Intersection &v)
    {
        int cur_id = vs.back();
        // FIXME
        // Vertex &curV = vertices[cur_id];
        vertices.push_back(v);
        vertices[cur_id].l_nee_id = vertices.size() - 1;
    }
    void append_bsdf(const Intersection &v)
    {
        // FIXME
        // Vertex &curV = vertices.at(vs.back());
        int cur_id = vs.back();
        vertices.push_back(v);
        vertices[cur_id].l_bsdf_id = vertices.size() - 1;
    }
    void setZero()
    {
        for (auto &v : vertices)
            v.setZero();
    }
    void setZero(int n)
    {
        assert(n <= vertices.size());
        for (int i = 0; i < n; i++)
            vertices[i].setZero();
    }
    std::string toString() const
    {
        Color::Modifier red(Color::FG_RED);
        Color::Modifier def(Color::FG_DEFAULT);
        std::stringstream ss;
        ss << "vertices" << std::endl;
        for (auto &v : vertices)
        {
            ss << red << "\t type: " << v.type << def << std::endl;
            ss << "\t p: " << v.p[0] << " " << v.p[1] << " " << v.p[2] << std::endl;
            ss << "\t pdf: " << v.pdf << std::endl;
            ss << "\t J: " << v.J << std::endl;
            ss << "\t value: " << v.value[0] << " " << v.value[1] << " " << v.value[2] << std::endl;
            ss << "\t bsdf_bsdf: " << v.bsdf_bsdf[0] << " " << v.bsdf_bsdf[1] << " " << v.bsdf_bsdf[2] << std::endl;
            ss << "\t nee_bsdf: " << v.nee_bsdf[0] << " " << v.nee_bsdf[1] << " " << v.nee_bsdf[2] << std::endl;
            ss << "\t bsdf_id: " << v.l_bsdf_id << std::endl;
            ss << "\t nee_id: " << v.l_nee_id << std::endl;
            ss << "\t med_id: " << v.medium_id << std::endl;
        }
        ss << "vs" << std::endl;

        for (auto &v : vs)
        {
            const Intersection &vv = vertices[v];
            ss << "\t id: " << v << std::endl;
            ss << "\t\t nee_id: " << vv.l_nee_id << std::endl;
            ss << "\t\t bsdf_id: " << vv.l_bsdf_id << std::endl;
        }
        return ss.str();
    }
};

struct LightPathAD
{
    // Primal path.
    LightPath val;
    // Companion path holding per-vertex derivative data; indexed like `val`.
    LightPath der;

    LightPathAD() {}

    // Copy `path` into both halves, then zero every vertex of the derivative half.
    LightPathAD(const LightPath &path)
        : val(path),
          der(path)
    {
        der.setZero();
    }

    // Only reserve capacity for the primal path, but actually resize the
    // derivative path to `capacity` vertices. NOTE(review): this asymmetry looks
    // intentional -- zeroGrad() relies on LightPath::setZero(n), which asserts
    // n <= der.vertices.size() -- confirm before "fixing" it.
    void reserve(int capacity)
    {
        val.reserve(capacity);
        der.resize(capacity);
    }

    // Empty both the primal and the derivative paths.
    void clear()
    {
        val.clear();
        der.clear();
    }

    // Zero the derivative entries that correspond to the current primal vertices.
    void zeroGrad()
    {
        const int live = val.vertices.size();
        der.setZero(live);
    }
};

namespace algorithm1_ptracer
{
    // Plain aggregate; every member has a default initializer so a
    // default-constructed Tuple never carries indeterminate values or wild pointers.
    struct Tuple
    {
        // Value-initialized to 0; the meaning of type codes is defined elsewhere.
        int type{};
        // Non-owning.
        RndSampler *sampler = nullptr;
        // Non-owning; all null until assigned.
        std::array<Intersection *, 3> intersections{};
    };

    // Value / derivative pair of Tuples.
    struct TupleAD
    {
        Tuple val;
        Tuple der;
    };

    // Path container for this namespace: `vertices` stores every vertex and
    // `path` stores, in append order, the indices of vertices added via append().
    struct LightPath
    {
        void clear()
        {
            vertices.clear();
            path.clear();
        }
        // Call Intersection::setZero() on every stored vertex.
        void setZero()
        {
            for (auto &v : vertices)
            {
                v.setZero();
            }
        }
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            path.push_back(static_cast<int>(vertices.size()) - 1);
        }
        void resize(int n)
        {
            vertices.resize(n);
            path.resize(n);
        }
        std::vector<Intersection> vertices;
        std::vector<int> path;
    };

// Deliberately included inside the namespace (it is also included inside
// algorithm1_bdpt below). NOTE(review): presumably template.h has no include
// guard and defines TypeAD in the enclosing namespace -- confirm before moving.
#include <core/template.h>
    using LightPathAD = TypeAD<LightPath>;

    // Returns (valid, anti_path, pdf1A, pdf1B, pdf2A, pdf2B); note that
    // anti_path is a single Intersection, not a LightPath.
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic(const Scene &scene, const LightPath &path);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_surf(const Scene &scene, const LightPath &path, RndSampler *sampler);
    std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_vol(const Scene &scene, const LightPath &path, RndSampler *sampler, bool is_equal_trans = false);
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    void d_eval(SceneAD &sceneAD,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);
} // namespace algorithm1_ptracer

namespace algorithm1
{
    // These entry points use the file-scope ::LightPath / ::LightPathAD types;
    // this namespace declares no path type of its own. Definitions are not in
    // this header.

    void d_velocity(SceneAD &sceneAD,
                    const Intersection &its,
                    Float d_u);

    // `path` is taken by non-const reference.
    Spectrum eval(const Scene &scene,
                  LightPath &path,
                  RndSampler *sampler);

    // d_eval and baseline share one signature. `d_scene` and `pathAD` are mutable;
    // NOTE(review): d_value is presumably the incoming adjoint of the path value -- confirm.
    void d_eval(const Scene &scene,
                Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value,
                RndSampler *sampler);
    void baseline(const Scene &scene,
                  Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value,
                  RndSampler *sampler);
} // namespace algorithm1

namespace algorithm1_vol
{
    Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler);
    void d_evalPath(const Scene &scene, Scene &d_scene,
                    const LightPath &path, LightPath &d_path,
                    Spectrum d_value);
    void d_evalVertex(const Scene &scene, Scene &d_scene,
                      const LightPath &path, LightPath &d_path,
                      RndSampler *sampler);
    void d_getPath(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path);
    void d_getPoint(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v);

    void d_eval(const Scene &scene, Scene &d_scene,
                LightPathAD &pathAD,
                Spectrum d_value, RndSampler *sampler);
    void baseline(const Scene &scene, Scene &d_scene,
                  LightPathAD &pathAD,
                  Spectrum d_value, RndSampler *sampler);

    void getPointFwd(const Scene &scene, Scene &d_scene, const Intersection &v, Intersection &d_v);
    void getPathFwd(const Scene &scene, Scene &d_scene, LightPath &path, LightPath &d_path);
    void evalVertexFwd(const Scene &scene, Scene &d_scene,
                       const LightPath &path, LightPath &d_path,
                       RndSampler *sampler);

    std::pair<Spectrum, Spectrum> evalPathFwd(const Scene &scene, Scene &d_scene,
                                              const LightPath &path, LightPath &d_path);
    std::pair<Spectrum, Spectrum> evalFwd(const Scene &scene, Scene &d_scene,
                                          LightPathAD &pathAD, RndSampler *sampler);
    std::pair<Spectrum, Spectrum> baselineFwd(const Scene &scene, Scene &d_scene,
                                              LightPathAD &pathAD, RndSampler *sampler);
}

namespace algorithm1_bdpt
{
    // Path container for this namespace: one weight entry per vertex, plus a
    // fixed set of three extra antithetic vertices.
    struct LightPath
    {
        LightPath() {}

        // Drop vertices and weights; isCameraPath, pixelIdx and antithetic_vtx are kept.
        void clear()
        {
            vertices.clear();
            weight.clear();
        }
        // Call Intersection::setZero() on every vertex and on all antithetic vertices.
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
            for (int i = 0; i < 3; i++)
                antithetic_vtx[i].setZero();
        }
        // Append a vertex together with a weight initialized to 0.
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            weight.push_back(0.f);
        }
        void reserve(int n)
        {
            vertices.reserve(n);
            weight.reserve(n);
        }
        int size() const { return static_cast<int>(vertices.size()); }
        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        // Default-initialized so a freshly constructed path never holds an
        // indeterminate bool (the user-provided ctor above does not set it).
        bool isCameraPath = false;
        Array2i pixelIdx;
        Intersection antithetic_vtx[3];
        std::vector<Intersection> vertices;
        // Parallel to `vertices` (append() pushes one entry into each).
        std::vector<Float> weight;
    };

// Deliberately included inside the namespace (also included inside
// algorithm1_ptracer above). NOTE(review): presumably template.h has no include
// guard and defines TypeAD in the enclosing namespace -- confirm before moving.
#include <core/template.h>
    using LightPathAD = TypeAD<LightPath>;
    Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD &cameraPathAD, int s, LightPathAD &lightPathAD, int t,
                        Float w, bool antithetic_success, Spectrum d_value);
    void d_evalVertex(SceneAD &sceneAD, LightPathAD &pathAD, int index);
    void d_getPoint(SceneAD &sceneAD, LightPathAD &pathAD, int index, bool antithetic);

} // namespace algorithm1_bdpt