#pragma once

#include <core/fwd.h>
#include <core/utils.h>
#include <core/nanoflann.hpp>
#include <omp.h>
#include <igl/copyleft/tetgen/tetrahedralize.h>
#include <igl/in_element.h>
#include <vector>
#include <render/common.h>
// Forward declaration: Scene is only taken by const reference in this header.
struct Scene;
// NOTE(review): neither macro below is referenced in this header; they
// presumably configure code elsewhere (e.g. an `#ifdef USE_KD_TREE` lookup
// path and the neighbour count of a nearest-vertex query) — confirm against
// their users before changing them.
#define USE_KD_TREE
#define NUM_NEAREST_VERTICES 100

/// Tetrahedral mesh produced by TetGen (via libigl) from a set of vertices and
/// polygonal faces, together with a cached volume per tetrahedron and an
/// igl::AABB tree built over the tetrahedra.
///
/// query/queryAD/in_element/getBarycentric/getVertex are defined out of line;
/// NOTE(review): they presumably use m_tree for point location — confirm in
/// the implementation file.
struct TetrahedronMesh
{

    /// Unsigned volume of the tetrahedron (a, b, c, d):
    /// |(a - d) . ((b - d) x (c - d))| / 6.
    inline static double tetra_volAD(const Vector3d &a, const Vector3d &b,
                                     const Vector3d &c, const Vector3d &d)
    {
        return std::abs((a - d).dot((b - d).cross(c - d))) / 6.;
    }

    TetrahedronMesh() = default;

    /// Tetrahedralizes the input with TetGen switches "pYQ" and builds the
    /// derived data (AABB tree, per-tet volumes); leaves state == ESConfigured.
    ///
    /// @param vertices input vertex positions (also copied into m_vertices)
    /// @param faces    input polygons, `degree` (3 or 4) vertex indices each
    /// @param ids      (shape_id, vertex_id) for every vertex, copied into m_ids
    template <int degree>
    TetrahedronMesh(const std::vector<Vector> &vertices,
                    const std::vector<Eigen::Matrix<int, degree, 1>> &faces,
                    const std::vector<std::pair<int, int>> &ids) : m_vertices(vertices), m_ids(ids)
    {
        static_assert(degree == 3 || degree == 4, "faces must have 3 or 4 vertex indices");
        Eigen::MatrixXd V;
        Eigen::MatrixXi F;
        toEigen<degree>(vertices, faces, V, F);

        Eigen::MatrixXi TF;
        // Per the TetGen manual: p = tetrahedralize a piecewise linear complex,
        // Y = preserve the input surface mesh, Q = quiet.
        [[maybe_unused]] const int status =
            igl::copyleft::tetgen::tetrahedralize(V, F, "pYQ", m_TV, m_TT, TF);
        assert(status == 0);
        // NOTE the tetrahedron mesh will have new vertices
        // assert(m_TV.rows() == V.rows());
        // assert((m_TV - V).norm() < 1e-5);
        // std::vector<Eigen::Vector4i> tetra_faces;
        // FIXME reject all degenerated tetrahedra
        // for (int i = 0; i < m_TT.rows(); ++i)
        // {
        //     // check all face normals. filter the tetrahedrons
        //     bool success = true;
        //     for (int j = 0; j < 4; ++j)
        //     {
        //         Vector3d a = m_TV.row(m_TT(i, j));
        //         Vector3d b = m_TV.row(m_TT(i, (j + 1) % 4));
        //         Vector3d c = m_TV.row(m_TT(i, (j + 2) % 4));
        //         Vector3d v1 = (b - a).normalized();
        //         Vector3d v2 = (c - a).normalized();
        //         Vector3d n1 = v1.cross(v2);
        //         if (n1.norm() < 1e-5)
        //         {
        //             success = false;
        //             break;
        //         }
        //         n1 = n1.normalized();
        //         for (int k = 0; k < 4; ++k)
        //         {
        //             if (j == k)
        //                 continue;
        //             Vector3d a = m_TV.row(m_TT(i, k));
        //             Vector3d b = m_TV.row(m_TT(i, (k + 1) % 4));
        //             Vector3d c = m_TV.row(m_TT(i, (k + 2) % 4));
        //             Vector3d v1 = (b - a).normalized();
        //             Vector3d v2 = (c - a).normalized();
        //             Vector3d n2 = v1.cross(v2).normalized();
        //             if (n1.cross(n2).norm() < 1e-5)
        //             {
        //                 success = false;
        //                 break;
        //             }
        //         }
        //     }
        //     if (success)
        //         tetra_faces.push_back(m_TT.row(i));
        // }

        // m_TT = Eigen::MatrixXi(tetra_faces.size(), 4);
        // for (int i = 0; i < tetra_faces.size(); ++i)
        //     m_TT.row(i) = tetra_faces[i];

        // FIXME
        buildTetQuantities();
    }

    /// Vertex indices (rows of m_TV) of tetrahedron i.
    Vector4i getTet(int i) const
    {
        return m_TT.row(i);
    }

    /// Defined out of line.
    Vector getVertex(const Scene &scene, int i) const;

    /// Re-initializes this mesh in place from the same kind of inputs as the
    /// constructor, using TetGen switches "-Q", then rebuilds the derived data.
    /// In debug builds it asserts that TetGen neither added nor moved vertices.
    template <int degree>
    void setTetmesh(const std::vector<Vector> &vertices,
                    const std::vector<Eigen::Matrix<int, degree, 1>> &faces,
                    const std::vector<std::pair<int, int>> &ids)
    {
        static_assert(degree == 3 || degree == 4, "faces must have 3 or 4 vertex indices");
        m_vertices = vertices;
        m_ids = ids;
        Eigen::MatrixXd V;
        Eigen::MatrixXi F;
        toEigen<degree>(vertices, faces, V, F);

        Eigen::MatrixXi TF;
        // Q = quiet. Unlike the constructor there is no 'p' switch, so TetGen
        // presumably treats the input as a point set — which is why the vertex
        // set is expected to come back unchanged below.
        [[maybe_unused]] const int status =
            igl::copyleft::tetgen::tetrahedralize(V, F, "-Q", m_TV, m_TT, TF);
        assert(status == 0);
        assert(m_TV.rows() == V.rows());
        assert((m_TV - V).norm() < 1e-5);
        buildTetQuantities();
    }

    /// Minimal nanoflann dataset adaptor over a flat array of 3-D points.
    template <typename T>
    struct PointCloud
    {
        struct Point
        {
            T x, y, z;
        };

        std::vector<Point> pts;

        inline size_t kdtree_get_point_count() const { return pts.size(); }

        /// Coordinate `dim` (0 = x, 1 = y, anything else = z) of point `idx`.
        inline T kdtree_get_pt(const size_t idx, const size_t dim) const
        {
            if (dim == 0)
                return pts[idx].x;
            else if (dim == 1)
                return pts[idx].y;
            else
                return pts[idx].z;
        }

        /// Returning false lets nanoflann compute the bounding box itself.
        template <class BBOX>
        bool kdtree_get_bbox(BBOX & /* bb */) const
        {
            return false;
        }
    };

    /// 3-D k-d tree with squared-L2 distance over a PointCloud<T>.
    template <typename T>
    using KDtree = nanoflann::KDTreeSingleIndexAdaptor<
        nanoflann::L2_Simple_Adaptor<T, PointCloud<T>>, PointCloud<T>, 3>;

    // NOTE: make sure the gradient can propagate to the vertices
    // Defined out of line; q and J are output parameters.
    bool query(const Vector3d &p, Vector &q, Float &J) const;

    // NOTE: make sure the gradient can propagate to the vertices
    // Defined out of line; scene-aware variant of query().
    bool queryAD(const Scene &scene, const Vector3d &p, Vector &q, Float &J) const;

    /// Defined out of line. NOTE(review): presumably the index of the tet
    /// containing p — confirm the "not found" value in the implementation.
    int in_element(const Vector3d &p) const;
    /// Defined out of line: barycentric coordinates of p in tetrahedron tet_id.
    Vector4 getBarycentric(int tet_id, const Vector3d &p) const;

    /// Volume of tetrahedron i. With PREPROCESS defined this is the value cached
    /// from m_TV at build time; otherwise it is recomputed from m_vertices,
    /// which requires every index in m_TT to be a valid m_vertices index (not
    /// guaranteed when TetGen inserted extra vertices).
    Float getVol(int i) const
    {
#ifdef PREPROCESS
        return m_vol[i];
#else
        assert(m_TT.row(i).maxCoeff() < static_cast<int>(m_vertices.size()));
        const Vector3 &a = m_vertices[m_TT(i, 0)], &b = m_vertices[m_TT(i, 1)],
                      &c = m_vertices[m_TT(i, 2)], &d = m_vertices[m_TT(i, 3)];
        return tetra_volAD(a, b, c, d);
#endif
    }
    /// Defined out of line.
    Float getVol(const Scene &scene, int i) const;

    EState state = EState::ESUninit;

    // m_vertices is managed by the containing class, which is Shape
    std::vector<Vector> m_vertices;
    // store the corresponding shape id and vertex id of each vertex
    // in order to propagate the gradient to the vertices
    std::vector<std::pair<int, int>> m_ids; // <shape_id, vertex_id>
    // Per-tetrahedron volume cached by buildTetQuantities(); indexed like m_TT rows.
    std::vector<Float> m_vol;

    // TetGen output: vertex positions, one row (x, y, z) per vertex.
    Eigen::MatrixXd m_TV;
    // TetGen output: one row of 4 indices into m_TV per tetrahedron.
    Eigen::MatrixXi m_TT;
    // AABB tree over (m_TV, m_TT), rebuilt by buildTetQuantities().
    igl::AABB<Eigen::MatrixXd, 3> m_tree;

private:
    /// Copies the std::vector inputs into the dense matrices TetGen expects:
    /// one row per vertex in V (3 columns), one row per face in F (`degree` columns).
    template <int degree>
    static void toEigen(const std::vector<Vector> &vertices,
                        const std::vector<Eigen::Matrix<int, degree, 1>> &faces,
                        Eigen::MatrixXd &V, Eigen::MatrixXi &F)
    {
        V.resize(vertices.size(), 3);
        for (int i = 0; i < static_cast<int>(vertices.size()); ++i)
            V.row(i) = vertices[i];

        F.resize(faces.size(), degree);
        for (int i = 0; i < static_cast<int>(faces.size()); ++i)
            F.row(i) = faces[i];
    }

    /// Finishes construction once m_TV/m_TT hold the TetGen output: builds the
    /// AABB tree, caches every tetrahedron's volume in m_vol, and marks the
    /// mesh as configured.
    void buildTetQuantities()
    {
        m_tree.init(m_TV, m_TT);
        m_vol.resize(m_TT.rows());
        for (int i = 0; i < m_TT.rows(); ++i)
        {
            const Vector3d a = m_TV.row(m_TT(i, 0)), b = m_TV.row(m_TT(i, 1)),
                           c = m_TV.row(m_TT(i, 2)), d = m_TV.row(m_TT(i, 3));
            // Go through tetra_volAD (std::abs): an unqualified abs() on a double
            // can bind to C's int abs(int) and truncate sub-unit volumes to 0.
            m_vol[i] = tetra_volAD(a, b, c, d);
        }
        state = ESConfigured;
    }
};
