#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

#include <core/fwd.h>
#include <core/ptr.h>
#include <core/macro.h>

struct Ray;
struct Intersection;
struct Medium;

/// Status tag with four explicitly numbered values (0..3).
/// NOTE(review): the names suggest an init -> loaded -> configured lifecycle,
/// but the code that assigns and checks these values is not in this header —
/// confirm against the users before relying on that reading.
enum EState
{
    ESInit = 0,
    ESLoaded = 1,
    ESConfigured = 2,
    ESUninit = 3
};

/// Measure tag; stored next to a pdf in PositionSamplingRecord::measure.
/// NOTE(review): presumably names the measure the pdf is expressed in —
/// confirm with the sampling routines.
enum EMeasure
{
    /// Invalid measure
    EMInvalid = 0,
    /// Solid angle measure
    EMSolidAngle = 1,
    /// Length measure
    EMLength = 2,
    /// Area measure
    EMArea = 3,
    /// Discrete measure
    EMDiscrete = 4
};

/// Mode selector taken by rayIntersectTriangle2().
/// NOTE(review): what each mode changes is decided by that function's
/// implementation, which is not part of this header.
enum IntersectionMode
{
    EMaterial = 0,
    ESpatial = 1,
    EReference = 2
};

/// Plain record describing a sampled position; base of DirectSamplingRecord.
/// No member has a default initializer, so the int fields (and, assuming
/// Float is an arithmetic type, the Float fields) are indeterminate until
/// the producer assigns them.
struct PositionSamplingRecord
{
    // presumably the sampled point and the normal there — TODO confirm
    Vector p, n;
    // presumably surface parameterization coordinates of p — TODO confirm
    Vector2 uv;
    // NOTE(review): looks like a Jacobian factor; its exact definition is not visible here
    Float J;
    // density value; `measure` tags the measure that goes with it
    Float pdf;
    EMeasure measure;
    // NOTE: for algo1
    // presumably identify the shape / triangle containing p and p's
    // barycentric coordinates inside it — verify against the producer
    int shape_id;
    int tri_id;
    Vector2 barycentric;
};

/// PositionSamplingRecord extended with a reference point and the quantities
/// relating that reference point to the sampled position p.
/// Only `ref`, `refN` and `interactions` are set by the inline constructors;
/// the remaining fields are left for the caller to fill in.
struct DirectSamplingRecord : public PositionSamplingRecord
{
    Vector ref;
    /* A reference point inside a medium or on a transmissive surface uses refN = 0. */
    Vector refN;
    Vector dir; // direction pointing from ref to p
    Float dist;
    Float G;
    Spectrum emittance;

    int interactions = 0;

    /// Reference point without a normal: refN becomes the zero vector.
    DirectSamplingRecord(const Vector &ref)
        : DirectSamplingRecord(ref, Vector::Zero())
    {
    }

    /// Reference point with an explicit normal.
    DirectSamplingRecord(const Vector &ref, const Vector &refN)
        : ref(ref), refN(refN)
    {
    }

    /// Construct from an intersection record; defined out of line.
    DirectSamplingRecord(const Intersection &refIts);
};

/// Record describing a sampled point on an edge.
/// No member has a default initializer; the producer must fill all of them.
struct EdgeSamplingRecord
{
    Float t;    // [0,1]
    Vector ref; // point on boundary
    Float pdf;
    int edge_id;
    int shape_id; // shape to which the sampled edge belongs
    // NOTE(review): presumably the id of a medium associated with the edge — confirm
    int med_id;
};

/// EdgeSamplingRecord extended with a direction.
/// NOTE(review): presumably the direction of the ray leaving the sampled
/// edge point `ref` — confirm with the sampler that fills it.
struct EdgeRaySamplingRecord : EdgeSamplingRecord
{
    Vector dir;
};

/// EdgeRaySamplingRecord extended with two end-point descriptions, suffixed
/// `_S` and `_D`. Each end point is either a surface vertex (shape / triangle
/// ids + 3 barycentric coordinates) or a volume vertex (medium / tetrahedron
/// ids + 4 barycentric coordinates); `onSurface_*` says which set applies.
/// Ids default to -1, pdfs to 1, and both end points default to "on surface".
/// NOTE(review): S / D presumably stand for the source-side and
/// detector-side end points — confirm.
struct BoundarySamplingRecord : public EdgeRaySamplingRecord
{
    // true: the _S end point is a surface vertex; false: a volume vertex
    bool onSurface_S = true;
    // for surface vertex
    int shape_id_S = -1;
    int tri_id_S = -1;
    Vector3 barycentric3_S;
    // for volume vertex
    int med_id_S = -1;
    int tet_id_S = -1;
    Vector4 barycentric4_S;
    Float pdf_S = 1.; // pdf of sampling a volume vertex, or the probability of sampling a surface vertex

    // same layout as above, for the _D end point
    bool onSurface_D = true;
    int shape_id_D = -1;
    int tri_id_D = -1;
    Vector3 barycentric3_D;
    int med_id_D = -1;
    int tet_id_D = -1;
    Vector4 barycentric4_D;
    Float pdf_D = 1.;
};

/// Bundle of five vectors attached to a pixel-boundary segment: one for xD,
/// one for xB and three for xS (xS_0, xS_1, xS_2).
/// NOTE(review): d_normal_velocity_pixel() takes both a `seg` and a `d_seg`
/// of this type, so the same layout is presumably used for positions and for
/// their gradients — confirm.
struct PixelBoundarySegmentInfo
{
    Vector xD,
        xB,
        xS_0, xS_1, xS_2;

    /// Pack xS_0, xS_1, xS_2 into a 9-vector, x/y/z of each in that order.
    Eigen::Matrix<Float, 9, 1> getVelocities() const
    {
        Eigen::Matrix<Float, 9, 1> velocities;
        velocities << xS_0.x(), xS_0.y(), xS_0.z(),
            xS_1.x(), xS_1.y(), xS_1.z(),
            xS_2.x(), xS_2.y(), xS_2.z();
        return velocities;
    }

    /// Largest absolute coefficient over all five vectors.
    Float maxCoeff() const
    {
        Float res = xS_0.cwiseAbs().maxCoeff();
        res = std::max(res, xS_1.cwiseAbs().maxCoeff());
        res = std::max(res, xS_2.cwiseAbs().maxCoeff());
        res = std::max(res, xB.cwiseAbs().maxCoeff());
        res = std::max(res, xD.cwiseAbs().maxCoeff());
        return res;
    }

    /// Largest absolute coefficient over the selected groups only.
    /// @param S  include xS_0, xS_1, xS_2
    /// @param B  include xB
    /// @param D  include xD
    /// @return   the maximum, or 0 when no group is selected
    Float maxCoeff(bool S, bool B, bool D) const
    {
        Float res = 0.0;
        if (S)
        {
            // All three xS vectors take part (previously xS_0 was tested
            // three times and xS_1 / xS_2 were ignored).
            res = std::max(res, xS_0.cwiseAbs().maxCoeff());
            res = std::max(res, xS_1.cwiseAbs().maxCoeff());
            res = std::max(res, xS_2.cwiseAbs().maxCoeff());
        }

        if (B)
        {
            res = std::max(res, xB.cwiseAbs().maxCoeff());
        }

        if (D)
        {
            res = std::max(res, xD.cwiseAbs().maxCoeff());
        }
        return res;
    }

    /// Reset all five vectors to zero.
    void setZero()
    {
        xS_0.setZero();
        xS_1.setZero();
        xS_2.setZero();

        xB.setZero();

        xD.setZero();
    }
};

/// Bundle of eight vectors attached to a boundary segment: three for xS
/// (xS_0..2), two for xB (xB_0, xB_1) and three for xD (xD_0..2).
/// NOTE(review): d_normal_velocity() takes both a `seg` and a `d_seg` of this
/// type, so the same layout is presumably used for positions and for their
/// gradients — confirm.
struct BoundarySegmentInfo
{
    Vector xS_0, xS_1, xS_2,
        xB_0, xB_1,
        xD_0, xD_1, xD_2;

    /// Largest absolute coefficient over all eight vectors.
    Float maxCoeff() const
    {
        Float res = xS_0.cwiseAbs().maxCoeff();
        res = std::max(res, xS_1.cwiseAbs().maxCoeff());
        res = std::max(res, xS_2.cwiseAbs().maxCoeff());
        res = std::max(res, xB_0.cwiseAbs().maxCoeff());
        res = std::max(res, xB_1.cwiseAbs().maxCoeff());
        res = std::max(res, xD_0.cwiseAbs().maxCoeff());
        res = std::max(res, xD_1.cwiseAbs().maxCoeff());
        res = std::max(res, xD_2.cwiseAbs().maxCoeff());
        return res;
    }

    /// Largest absolute coefficient over the selected groups only.
    /// @param S  include xS_0, xS_1, xS_2
    /// @param B  include xB_0, xB_1
    /// @param D  include xD_0, xD_1, xD_2
    /// @return   the maximum, or 0 when no group is selected
    Float maxCoeff(bool S, bool B, bool D) const
    {
        Float res = 0.0;
        if (S)
        {
            // All three xS vectors take part (previously xS_0 was tested
            // three times and xS_1 / xS_2 were ignored).
            res = std::max(res, xS_0.cwiseAbs().maxCoeff());
            res = std::max(res, xS_1.cwiseAbs().maxCoeff());
            res = std::max(res, xS_2.cwiseAbs().maxCoeff());
        }

        if (B)
        {
            res = std::max(res, xB_0.cwiseAbs().maxCoeff());
            res = std::max(res, xB_1.cwiseAbs().maxCoeff());
        }

        if (D)
        {
            res = std::max(res, xD_0.cwiseAbs().maxCoeff());
            res = std::max(res, xD_1.cwiseAbs().maxCoeff());
            res = std::max(res, xD_2.cwiseAbs().maxCoeff());
        }
        return res;
    }

    /// Reset all eight vectors to zero.
    void setZero()
    {
        xS_0.setZero();
        xS_1.setZero();
        xS_2.setZero();

        xB_0.setZero();
        xB_1.setZero();

        xD_0.setZero();
        xD_1.setZero();
        xD_2.setZero();
    }
};

// struct Base
// {
//     virtual ~Base() {}
//     virtual Base *clone() const = 0;
// };

// template <typename Derived>
// struct BaseCRTP : Base
// {
//     virtual Base *clone() const
//     {
//         return new Derived(static_cast<Derived const &>(*this));
//     }
// };

/// Rendering parameters passed around as one value.
///
/// Every field has a default member initializer, so a default-constructed
/// object (and the fields the 8-argument constructor does not take) hold
/// defined values instead of indeterminate ones. `mode`, `ddistCoeff` and
/// `grad_threshold` default to the same values the 8-argument constructor
/// uses; the remaining numeric fields default to 0 and `quiet` to false.
struct RenderOptions
{
    RenderOptions() = default;

    /// @param seed                        stored in `seed`
    /// @param num_samples                 stored in `num_samples`
    /// @param max_bounces                 stored in `max_bounces`
    /// @param num_samples_primary_edge    stored in `num_samples_primary_edge`
    /// @param num_samples_secondary_edge  stored in `num_samples_secondary_edge`
    ///                                    and copied to the `_direct` and
    ///                                    `_indirect` variants
    /// @param quiet                       stored in `quiet`
    /// @param mode                        stored in `mode` (default -1)
    /// @param ddistCoeff                  stored in `ddistCoeff` (default 0)
    ///
    /// `sppe0` and `num_samples_secondary_edge_direct_point_light` are not
    /// parameters; they keep their member defaults (0) until assigned.
    RenderOptions(uint64_t seed, int num_samples, int max_bounces,
                  int num_samples_primary_edge, int num_samples_secondary_edge,
                  bool quiet, int mode = -1, float ddistCoeff = 0.0f)
        : seed(seed), num_samples(num_samples), max_bounces(max_bounces),
          num_samples_primary_edge(num_samples_primary_edge),
          num_samples_secondary_edge(num_samples_secondary_edge), quiet(quiet),
          mode(mode), ddistCoeff(ddistCoeff),
          // inside the initializer expressions the parameter name hides the
          // member, so these copy the constructor argument
          num_samples_secondary_edge_direct(num_samples_secondary_edge),
          num_samples_secondary_edge_indirect(num_samples_secondary_edge)
    {
    }

    uint64_t seed = 0;
    int num_samples = 0;
    int max_bounces = 0;
    int num_samples_primary_edge = 0;   // Camera ray
    int num_samples_secondary_edge = 0; // Secondary (i.e., reflected/scattered) rays
    // NOTE(review): not set by any constructor; meaning not visible in this header
    int sppe0 = 0;
    bool quiet = false;
    int mode = -1;
    Float ddistCoeff = 0.0;

    // For path-space differentiable rendering
    int num_samples_secondary_edge_direct = 0;
    int num_samples_secondary_edge_indirect = 0;
    int num_samples_secondary_edge_direct_point_light = 0;
    float grad_threshold = 1e8f;

    int block_size = 4;
};

/**
 * @brief Deep-copy a vector of owning pointers through T::clone().
 *
 * Each non-null element is replaced by the result of calling clone() on it;
 * null elements are copied as null (they used to be dereferenced).
 *
 * @param v  source pointers; the pointees are not modified by this function
 * @return   a vector of the same length whose non-null entries are the
 *           pointers returned by clone(); the caller owns them
 */
template <typename T>
std::vector<T *> deepcopy(const std::vector<T *> &v)
{
    std::vector<T *> copies;
    copies.reserve(v.size());
    for (T *src : v)
    {
        T *dup = nullptr;
        if (src)
            dup = src->clone();
        copies.push_back(dup);
    }
    return copies;
}

/**
 * @brief Deep-copy a vector of owning pointers through T's copy constructor.
 *
 * Each non-null element is replaced by `new T(*element)`; null elements are
 * copied as null (they used to be dereferenced). Because the copy is made as
 * a T, pointees of a type derived from T are sliced.
 *
 * @param v  source pointers; the pointees are not modified
 * @return   a vector of the same length; the caller owns the new objects
 */
template <typename T>
std::vector<T *> deepcopy1(const std::vector<T *> &v)
{
    std::vector<T *> copies;
    copies.reserve(v.size());
    for (T *src : v)
        copies.push_back(src ? new T(*src) : nullptr);
    return copies;
}
namespace common
{
    /// Call setZero() on every element of a vector of objects.
    template <typename T>
    void setZero(std::vector<T> &v)
    {
        for (auto it = v.begin(); it != v.end(); ++it)
            it->setZero();
    }

    /// Call setZero() on every pointee of a vector of pointers.
    /// The pointers are dereferenced unconditionally.
    template <typename T>
    void setZero(std::vector<T *> &v)
    {
        for (auto it = v.begin(); it != v.end(); ++it)
            (*it)->setZero();
    }
}

/**
 * @brief Element-wise accumulate: v1[i] += v2[i] for every index.
 *
 * The index is a std::size_t so it matches the type of size() (the former
 * `int` index mixed signed and unsigned in the loop condition).
 *
 * @param v1  destination, modified in place
 * @param v2  source; must have the same length as v1 (checked by assert only)
 */
template <typename T>
void merge_vector(std::vector<T> &v1, const std::vector<T> &v2)
{
    assert(v1.size() == v2.size());
    for (std::size_t i = 0; i < v1.size(); ++i)
        v1[i] += v2[i];
}

/**
 * @brief Element-wise merge for vectors of pointer-like elements:
 *        v1[i]->merge(v2[i]) for every index.
 *
 * T is the element type itself (e.g. `Foo *`), and each v2 element is passed
 * to merge() as-is. The index is a std::size_t so it matches size().
 *
 * @param v1  destination handles; their pointees are modified
 * @param v2  source handles; must have the same length as v1 (assert only)
 */
template <typename T>
void merge_pointer_vector(std::vector<T> &v1, const std::vector<T> &v2)
{
    assert(v1.size() == v2.size());
    for (std::size_t i = 0; i < v1.size(); ++i)
        v1[i]->merge(v2[i]);
}

/**
 * @brief Element-wise merge for vectors of objects:
 *        v1[i].merge(v2[i]) for every index.
 *
 * The index is a std::size_t so it matches the type of size().
 *
 * @param v1  destination objects, modified in place
 * @param v2  source objects; must have the same length as v1 (assert only)
 */
template <typename T>
void merge_object_vector(std::vector<T> &v1, const std::vector<T> &v2)
{
    assert(v1.size() == v2.size());
    for (std::size_t i = 0; i < v1.size(); ++i)
        v1[i].merge(v2[i]);
}

/**
 * @brief `delete` every pointer stored in v.
 *
 * The vector is taken by const reference and is not modified, so after the
 * call it still holds the (now dangling) pointer values; the caller is
 * responsible for clearing or discarding it. Null entries are harmless.
 *
 * @param v  pointers to objects allocated with `new`
 */
template <typename T>
void delete_vector(const std::vector<T *> &v)
{
    for (T *ptr : v)
        delete ptr;
}

/**
 * @brief Direction of the line along which two planes intersect.
 *
 * Computed as the normalized cross product n1 x n2.
 *
 * @param n1        normal of the first plane
 * @param n2        normal of the second plane
 * @return Vector   direction of the intersection
 */
inline Vector planeIntersection(const Vector &n1, const Vector &n2)
{
    const Vector lineDir = n1.cross(n2);
    return lineDir.normalized();
}

/* Project v along dir onto the plane defined by normal n and return the
 * resulting direction. */
inline Vector project(const Vector &v,
                      const Vector &dir,
                      const Vector &n)
{
    // v x dir is the normal of the plane spanned by v and dir; the result is
    // the direction in which that plane meets the plane with normal n.
    return planeIntersection(n, v.cross(dir));
}

// paper p.9 fig.8
// Scalar geometric helpers, declared here and defined out of line.
// `noinline` asks the compiler not to inline the function; `optnone` is a
// Clang attribute that disables optimization for the function (other
// compilers may warn about it as an unknown attribute).
// NOTE(review): the header does not record why these attributes are needed,
// nor the formulas; the parameter names (xS, xB, xD, v, n) follow the figure
// referenced above — see the implementation for the exact definitions.
__attribute__((noinline)) Float sinB(const Vector &xS,
                                     const Vector &xB, const Vector &v,
                                     const Vector &xD);
__attribute__((noinline)) Float sinD(const Vector &xS,
                                     const Vector &xB, const Vector &v,
                                     const Vector &xD, const Vector &n);
// Note the argument order: dlS_dlB takes xS first and xD fourth, ...
__attribute__((optnone)) Float dlS_dlB(const Vector &xS,
                                       const Vector &xB, const Vector &v,
                                       const Vector &xD, const Vector &n);

// ... while dlD_dlB takes xD first and xS fourth.
__attribute__((optnone)) Float dlD_dlB(const Vector &xD,
                                       const Vector &xB, const Vector &v,
                                       const Vector &xS, const Vector &n);

__attribute__((noinline)) Float dA_dw(const Vector &ref,
                                      const Vector &p, const Vector &n);

__attribute__((noinline)) Float dASdAD_dlBdwBdrD(const Vector &xS, const Vector &nS,
                                                 const Vector &xB, const Vector &v,
                                                 const Vector &xD);

// Three overloads of a scalar term between a reference point `ref` and a
// point `p`, taking no normal, the normal at p, or the normals at both.
// NOTE(review): presumably the usual geometry term with the corresponding
// cosine factors — the formula is in the implementation file; confirm there.
Float geometric(const Vector &ref,
                const Vector &p);

Float geometric(const Vector &ref,
                const Vector &p, const Vector &n);

Float geometric(const Vector &ref, const Vector &refN,
                const Vector &p, const Vector &n);

// Overload set returning a scalar "normal velocity" for different
// parameterizations of the xS / xB / xD configuration. All are defined out
// of line. The overloads differ only in how each end is described (a single
// point vs. several vertices) and in the extra scalars `t` / `dist`.
// NOTE(review): several overloads have long, same-typed parameter lists that
// differ only in argument count or order (the first two take xD first, the
// rest take xS first) — double-check call sites against these signatures.
Float normal_velocity(const Vector &xD,
                      const Vector &xB, const Vector &v,
                      const Vector &xS, const Vector &n);

Float normal_velocity(const Vector &xD,
                      const Vector &xB, const Vector &v,
                      const Vector &xS_0, const Vector &xS_1, const Vector &xS_2,
                      const Vector &n);

Float normal_velocity(const Vector &xS_0, const Vector &xS_1, const Vector &xS_2,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t, const Vector &dir,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2);
// compatible with point lights
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t, const Vector &dir,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2);
// xS -> surface vertex xD
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2);
// xS -> volume vertex xD
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2, const Vector &xD_3, const Float dist);

// Intersect `ray` with the triangle (v0, v1, v2); defined out of line.
// NOTE(review): the meaning of the returned Vector and the effect of `mode`
// are not visible in this header — see the implementation.
Vector rayIntersectTriangle2(const Vector &v0, const Vector &v1, const Vector &v2, const Ray &ray, IntersectionMode mode);
/// Weight of strategy A against strategy B using squared densities:
/// pdfA^2 / (pdfA^2 + pdfB^2).
/// No guard is applied: if both arguments are zero the quotient is 0/0.
inline Float miWeight(Float pdfA, Float pdfB)
{
    const Float sqA = pdfA * pdfA;
    const Float sqB = pdfB * pdfB;
    return sqA / (sqA + sqB);
}

// Companions of normal_velocity() that write into `d_seg` (the only
// non-const argument); defined out of line.
// NOTE(review): by naming, `d_res` is presumably the incoming derivative of
// the scalar result and `d_seg` receives the derivatives with respect to the
// fields of `seg`; whether d_seg is overwritten or accumulated into is not
// visible here — confirm in the implementation.
void d_normal_velocity(const BoundarySegmentInfo &seg, BoundarySegmentInfo &d_seg,
                       const Vector &xB_2, const Float t, const Vector &dir,
                       const Float d_res);

void d_normal_velocity_pixel(const PixelBoundarySegmentInfo &seg, PixelBoundarySegmentInfo &d_seg,
                             const Vector &nB,
                             const Float d_res);

// NOTE(review): this is an unnamed namespace inside a header, so every
// translation unit that includes the file gets its own internal-linkage copy
// of these definitions. It is kept unnamed here so existing users keep
// compiling; moving it to a named/detail namespace is a file-wide decision.
namespace
{
    /// Merge state of a GradientManager.
    enum EStatus
    {
        Einit = 0,   // per-worker buffers have not been merged since the last setZero()
        Emerged = 1, // merge() has run
    };

    /**
     * Per-worker accumulation buffers for an object of type T.
     *
     * Construction makes `nworker` copies of `d_scene` and zeroes them.
     * Workers write into get(i); merge() folds every buffer into the
     * referenced `d_scene`. If merge() was not called, the destructor does it.
     *
     * T must provide `merge(const T&)`-compatible and `setZero()` members and
     * be copyable.
     *
     * NOTE(review): the type holds a reference and merges in its destructor;
     * copying a GradientManager would make both copies merge into the same
     * target. Treat instances as non-copyable.
     */
    template <typename T>
    struct GradientManager
    {
        GradientManager(T &d_scene, int nworker)
            : d_scene(d_scene), nworker(nworker)
        {
            d_scenes.reserve(nworker);
            for (int i = 0; i < nworker; i++)
                d_scenes.push_back(d_scene); // copy
            setZero();
        }

        /// Merge on destruction unless merge() already ran.
        ~GradientManager()
        {
            if (status != Emerged)
                merge();
        }

        /// Fold every per-worker buffer into the target and mark as merged.
        /// The buffers are not cleared; call setZero() before reusing them.
        void merge()
        {
            for (auto &_d_scene : d_scenes)
                d_scene.merge(_d_scene);
            status = Emerged;
        }

        /// Buffer of worker i, 0 <= i < nworker (checked by assert only).
        T &get(int i)
        {
            assert(i >= 0 && i < nworker);
            return d_scenes[i];
        }

        /// Zero every per-worker buffer and reset the state to "not merged".
        void setZero()
        {
            for (auto &_d_scene : d_scenes)
                _d_scene.setZero();
            status = Einit;
        }

        T &d_scene;
        std::vector<T> d_scenes;
        int nworker;
        EStatus status = Einit;
    };

    /// Element-wise a[i] += b[i] for every index of `a`.
    /// `b` must be at least as long as `a` (checked by assert only); any
    /// extra elements of `b` are ignored.
    template <typename T>
    std::vector<T> &operator+=(std::vector<T> &a, const std::vector<T> &b)
    {
        assert(b.size() >= a.size());
        for (std::size_t i = 0; i < a.size(); ++i)
            a[i] += b[i];
        return a;
    }

    /**
     * Per-worker copies of a value of type T, combined with `+=`.
     *
     * Unlike GradientManager, the copies start as copies of `data` (they are
     * not zeroed) and nothing is merged automatically on destruction.
     */
    template <typename T>
    struct ThreadManager
    {
        ThreadManager(T &data, int nworker)
            : data(data), thread_data(nworker, data) {}

        /// Add every per-worker copy into the target and return the target.
        T &merge()
        {
            for (auto &_data : thread_data)
                data += _data;
            return data;
        }

        /// Copy owned by worker i (bounds checked by assert only).
        T &get(int i)
        {
            assert(i >= 0 && static_cast<std::size_t>(i) < thread_data.size());
            return thread_data[i];
        }

        T &data;
        std::vector<T> thread_data;
    };
}

// Process-wide flags. The variables are only declared here (`extern`) and
// must be defined in exactly one source file; the setter/getter functions
// are likewise defined out of line.
// NOTE(review): plain bools with no synchronization visible in this header —
// confirm they are not written while worker threads are reading them.
extern bool verbose;
void set_verbose(bool verbose);
bool get_verbose();

// NOTE(review): the meaning of `forward` is not visible in this header.
extern bool forward;
void set_forward(bool forward);
bool get_forward();