#pragma once

#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>
#include <tuple>
#include <vector>
#include <core/fwd.h>
#include <core/pmf.h>
#include <core/ptr.h>
#include <core/object.h>
#include <render/common.h>
#include "tetra.hpp"

// Forward declarations of the ray / intersection types. Within this header
// Ray and Intersection are only used through references (Shape::rayIntersect,
// Shape::getPoint), so their complete definitions are not required here.
struct Ray;
struct RayAD;
struct Intersection;
struct IntersectionAD;

/**
 * @brief A mesh edge: two endpoint vertex IDs, two face IDs, its length,
 *  a third vertex of face f0 that is not on the edge, and a mode flag.
 *
 *  A default-constructed Edge has every ID set to -1, a zero length and
 *  mode 0, and therefore reports isValid() == false.
 */
struct Edge
{
    Edge() = default;
    Edge(int v0, int v1, int f0, int f1,
         Float length, int v2, int mode)
        : v0(v0), v1(v1), f0(f0), f1(f1),
          length(length), v2(v2), mode(mode) {}

    /// Equality compares only the vertex IDs and the face IDs;
    /// length, v2 and mode are not part of the comparison.
    inline bool operator==(const Edge &rhs) const
    {
        const bool sameVertices = (v0 == rhs.v0) && (v1 == rhs.v1);
        const bool sameFaces = (f0 == rhs.f0) && (f1 == rhs.f1);
        return sameVertices && sameFaces;
    }

    /// An edge is valid when its first face ID is non-negative.
    inline bool isValid() const { return !(f0 < 0); }

    int v0 = -1, v1 = -1;  // vertex IDs
    int f0 = -1, f1 = -1;  // face IDs
    Float length = 0.0f;   // edge length

    int v2 = -1;  // a third vertex that is NOT on the edge but on f0
    int mode = 0; // 0: boundary; 1: convex; -1: concave
};

struct Sort_config
{
    Sort_config()
    {
        max_length = -1;
    }
    Sort_config(int _max_length, Float _cos_val_local, Float _cos_val_global)
    {
        assert(_max_length > 1);
        assert(_cos_val_local > -0.9999 && _cos_val_local < 0);
        assert(_cos_val_global > -0.9999 && _cos_val_global < 0);
        max_length = _max_length;
        cos_val_local = _cos_val_local;
        cos_val_global = _cos_val_global;
    }
    int max_length;
    Float cos_val_local;
    Float cos_val_global;
};

/**
 * @brief Triangle mesh shape.
 *
 *  Stores vertex positions, optional per-vertex normals and UVs, per-triangle
 *  vertex and UV indices, per-face normals and light/BSDF/medium IDs. It also
 *  holds the discrete distributions over edges and faces used for sampling.
 *
 *  When NORMAL_PREPROCESS is defined, vertex and face normals are read from
 *  the cached `normals` / `faceNormals` arrays. Otherwise they are recomputed
 *  from the current vertex positions on every query.
 */
struct Shape : Object
{
    /// Empty shape: all containers empty, counts 0, all IDs -1.
    Shape() {}

    Shape(const std::vector<Vector> &vertices, const std::vector<Vector3i> &indices,
          const std::vector<Vector2> &uvs, const std::vector<Vector> &normals,
          int num_vertices, int num_triangles, int light_id, int bsdf_id, int med_int_id, int med_ext_id);

    ~Shape() {}

    /// Same as the constructor above plus an edge-sorting configuration.
    /// NOTE(review): presumably stored into sort_config; TODO confirm in the .cpp.
    Shape(const std::vector<Vector> &vertices, const std::vector<Vector3i> &indices,
          const std::vector<Vector2> &uvs, const std::vector<Vector> &normals,
          int num_vertices, int num_triangles, int light_id, int bsdf_id, int med_int_id, int med_ext_id, const Sort_config &config);

    /**
     * @brief Combine `shape`'s vertices, normals, uvs and faceNormals into
     *  this shape using merge_vector.
     *  NOTE(review): merge_vector is defined elsewhere. It is presumably an
     *  element-wise accumulation (e.g. of derivatives); TODO confirm.
     */
    void merge(const Shape *shape)
    {
        merge_vector(vertices, shape->vertices);
        merge_vector(normals, shape->normals);
        merge_vector(uvs, shape->uvs);
        merge_vector(faceNormals, shape->faceNormals);
    }

    /// Zero vertices, normals and faceNormals via common::setZero.
    /// uvs, indices and all other members are left untouched.
    void setZero()
    {
        common::setZero(vertices);
        common::setZero(normals);
        common::setZero(faceNormals);
    }

    /**
     * @brief
     *  Load vertices, uvs, vertex indices, uv indices from *.obj file.
     *  Call configure() after loading to compute face normals, vertex normals,
     *  and other derived properties.
     */
    void load(const std::string &filename);

    /**
     * @brief
     *  Computes face normals, vertex normals, and other derived properties.
     */
    void configure();

    /// NOTE(review): presumably moves the shape by `stepSize` along
    /// derivative `derId`; TODO confirm against the definition.
    void advance(Float stepSize, int derId = 0);

    /**
     * @brief Overwrite vertex `id` and recompute derived properties.
     * @note Calls configure() on every call. When replacing many vertices,
     *  a single setVertices() call avoids repeated reconfiguration.
     */
    void setVertex(int id, const Vector &v)
    {
        assert(id >= 0 && static_cast<std::size_t>(id) < vertices.size());
        vertices[id] = v;
        this->configure();
    }

    /// Replace all vertices (same count required) and recompute derived properties.
    void setVertices(const std::vector<Vector> &_vertices)
    {
        assert(_vertices.size() == vertices.size());
        vertices = _vertices;
        this->configure();
    }

    const std::vector<Vector> &getVertices() const { return vertices; }

    /* ================= normal related ====================== */

    /// Vertex normal: cached value, or the normalized sum of the geometric
    /// normals of all faces adjacent to vertex `index`.
    Vector getVertexNormal(int index) const
    {
#ifdef NORMAL_PREPROCESS
        return normals[index];
#else
        Vector n = Vector::Zero();
        for (const int face : adjacentFaces[index])
            n += getGeoNormal(face);
        return n.normalized();
#endif
    }

    /// Geometric (flat) normal of triangle `index`: cached value, or the
    /// normalized cross product (v1 - v0) x (v2 - v0).
    /// Degenerate (zero-area) triangles are not rejected here.
    Vector3 getGeoNormal(int index) const
    {
#ifdef NORMAL_PREPROCESS
        return faceNormals[index];
#else
        const auto &ind = getIndices(index);
        const auto &v0 = getVertex(ind(0));
        const auto &v1 = getVertex(ind(1));
        const auto &v2 = getVertex(ind(2));
        auto n = (v1 - v0).cross(v2 - v0);
        return n.normalized();
#endif
    }

    /// Alias of getGeoNormal().
    Vector getFaceNormal(int tri_index) const
    {
        return getGeoNormal(tri_index);
    }

    Vector getShadingNormal(int tri_index, const Vector2 &barycentric) const;
    /* compute the adjacent faces of the vertices */
    void computeAdjacentFaces();
    void computeFaceDistribution();
    void computeFaceNormals();
    void computeVertexNormals();

    /// True when an interior or exterior medium is attached.
    inline bool isMediumTransition() const
    {
        return med_ext_id >= 0 || med_int_id >= 0;
    }
    inline bool hasUVs() const { return uvs.size() != 0; }
    inline bool hasNormals() const { return normals.size() != 0; }
    inline bool isEmitter() const { return light_id >= 0; }

    // Under FORWARD the vertex accessor returns by value (defined elsewhere);
    // otherwise it returns a reference into `vertices`.
#ifdef FORWARD
    const Vector3 getVertex(int index) const;
#else
    inline const Vector3 &getVertex(int index) const
    {
        return vertices[index];
    }
#endif

    /// Vertex indices of triangle `index`.
    inline const Vector3i &getIndices(int index) const
    {
        return indices[index];
    }
    inline const Vector2 &getUV(int index) const { return uvs[index]; }
    /// Sum of edge_distrb (presumably the total length of sampled edges).
    inline Float getEdgeTotLength() const { return edge_distrb.getSum(); }

    Float getArea(int index) const;

    // void samplePosition(int index, const Vector2 &rnd2, Vector &p, Vector &n) const;

    /**
     * @brief sample a position on a shape
     *
     * @param _rnd2
     * @param pRec
     *      pRec.p:     position
     *      pRec.n:     normal
     *      pRec.uv:    barycentric coordinate
     *      pRec.J:     Jacobian determinant for change of variable
     *      pRec.pdf:   pdf
     *      pRec.measure:   measure
     * @return int
     */
    int samplePosition(const Vector2 &rnd2, PositionSamplingRecord &pRec) const;

    void rayIntersect(int tri_index, const Ray &ray, Intersection &its,
                      IntersectionMode mode = ESpatial) const;

    /// (#edges x 2) matrix whose row i holds edges[i].v0 and edges[i].v1.
    Eigen::MatrixXi getEdges()
    {
        Eigen::MatrixXi ret(edges.size(), 2);
        for (std::size_t i = 0; i < edges.size(); i++)
        {
            ret(i, 0) = edges[i].v0;
            ret(i, 1) = edges[i].v1;
        }
        return ret;
    }

    /// Sum of face_distrb (presumably the total surface area).
    inline Float getArea() const { return face_distrb.getSum(); }

    /**
     * @brief Interpolate the UV of triangle `tri_index` at `barycentric`.
     *
     *  Weights are (1 - b0 - b1, b0, b1) for the triangle's three UVs. Without
     *  UVs, the fixed placeholders (0,0), (1,0), (1,1) are interpolated.
     */
    Vector2 getUV(int tri_index, const Vector2 &barycentric) const
    {
        Vector2 uvs0, uvs1, uvs2;
        if (hasUVs())
        {
            // uv_indices is only consulted when UVs exist. Indexing it for a
            // UV-less mesh could read past the end of an empty vector (UB).
            const Vector3i &ind = uv_indices[tri_index];
            uvs0 = getUV(ind(0));
            uvs1 = getUV(ind(1));
            uvs2 = getUV(ind(2));
        }
        else
        {
            uvs0 = Vector2{0, 0};
            uvs1 = Vector2{1, 0};
            uvs2 = Vector2{1, 1};
        }
        return (1.0f - barycentric(0) - barycentric(1)) * uvs0 +
               barycentric(0) * uvs1 +
               barycentric(1) * uvs2;
    }

    std::tuple<Vector, Vector, Float> getPoint(int tri_index, const Vector2 &_barycentric) const;
    void getPoint(const Intersection &its, Intersection &itsAD) const;
    void getPoint(int tri_index, const Vector2 &_barycentric, Intersection &its, Float &J) const;

    /**
     * @brief test if an edge is a silhouette edge viewed from p
     */
    bool isSihoulette(const Edge &edge, const Vector &p) const;
    bool save(const std::string &name) const;
    //=======================  Forward Autodiff ==========================
#ifdef FORWARD
    bool requires_grad = false;
    Float param = 0.f;
    int vertex_idx = -1; // -1: transforming all vertices
    bool is_tranlate = false;
    bool is_rotate = false;
    bool is_scale = false;
    bool is_velocity = false;
    Vector translation = Vector(1., 0., 0.);
    Vector rotation = Vector(0., 1., 0.);
    Vector scale = Vector(1., 1., 1.);
    std::vector<Vector> velocities;
    void setTranslation(const Vector &_translation)
    {
        is_tranlate = true;
        translation = _translation;
    }
    void setRotation(const Vector &_rotation)
    {
        is_rotate = true;
        rotation = _rotation;
    }
    void setScale(const Vector &_scale)
    {
        is_scale = true;
        scale = _scale;
    }
    /// Copy one velocity row per vertex (row count must equal #vertices).
    void setVelocities(const Eigen::MatrixX3d &_velocities)
    {
        is_velocity = true;
        assert(static_cast<std::size_t>(_velocities.rows()) == vertices.size());
        velocities.resize(vertices.size());
        for (std::size_t i = 0; i < vertices.size(); i++)
        {
            velocities[i] = _velocities.row(i);
        }
    }
    Vector transform(int index, Float param) const;
#endif

    //====================================================================

    std::vector<Vector3> vertices;  // vertex positions
    std::vector<Vector3> normals;   // per-vertex normals (may be empty)
    std::vector<Vector3i> indices;  // per-triangle vertex indices into `vertices`
    std::vector<Vector2> uvs;       // UV coordinates (may be empty)
    std::vector<Vector3i> uv_indices; // per-triangle indices into `uvs`
    Matrix4x4 m_to_world;           // NOTE(review): presumably object-to-world transform

    std::vector<std::vector<int>> adjacentFaces; // per-vertex list of adjacent face indices

    // Per triangle properties
    std::vector<Vector3> faceNormals;
    // std::vector<Matrix3x3> faceRotations;

    // Zero-initialized so a default-constructed Shape holds no indeterminate values.
    int num_vertices = 0;
    int num_triangles = 0;

    int light_id = -1;   // >= 0 when the shape is an emitter
    int bsdf_id = -1;
    int med_int_id = -1; // >= 0 when an interior medium is attached
    int med_ext_id = -1; // >= 0 when an exterior medium is attached

    bool m_use_face_normals = true;

    // =================Edge Sampling=================
    bool enable_edge = true;
    bool enable_draw = false;
    Sort_config sort_config;
    std::vector<Edge> edges;
    DiscreteDistribution edge_distrb;
    DiscreteDistribution face_distrb;
    // =================Edge Sampling=================

    PSDR_DECLARE_CLASS(Shape)
};

#ifdef NORMAL_PREPROCESS
// Only declared when NORMAL_PREPROCESS is defined, i.e. when Shape's
// getVertexNormal / getGeoNormal read the cached `normals` / `faceNormals`
// arrays instead of recomputing them.
// NOTE(review): judging by the `d_` prefix and the `d_shape` argument, these
// presumably propagate derivatives of the cached vertex / face normals back
// into d_shape. TODO confirm against their definitions.
void d_precompute_vertex_normal(const Shape &shape, Shape &d_shape);
void d_precompute_face_normal(const Shape &shape, Shape &d_shape);
void d_precompute_normal(const Shape &shape, Shape &d_shape);
#endif