#pragma once

#include <vector>
#include <igl/copyleft/tetgen/tetrahedralize.h>
#include <igl/in_element.h>
#include "fwd.h"
#include "omp.h"
#include "nanoflann.hpp"

// When defined, TetrahedronMesh::init() looks up neighbouring input vertices
// through a nanoflann k-d tree instead of looping over every input vertex.
#define USE_KD_TREE
// Number of nearest input vertices requested per k-d tree query in init().
#define NUM_NEAREST_VERTICES 100

struct TetrahedronMesh {
    inline static double tetra_vol(const Vector3d &a, const Vector3d &b, const Vector3d &c, const Vector3d &d) {
        return std::abs((a - d).dot((b - d).cross(c - d)))/6.;
    }

    inline static FloatAD tetra_vol_boundary(const VectorAD &a, const VectorAD &b, const VectorAD &c, const VectorAD &d) {
        return ((a - d).dot((b - d).cross(c - d))).abs()/6.;
    }


    TetrahedronMesh() = default;

    template <typename T>
    struct PointCloud {
        struct Point
        {
            T  x,y,z;
        };

        std::vector<Point>  pts;

        inline size_t kdtree_get_point_count() const { return pts.size(); }

        inline T kdtree_get_pt(const size_t idx, const size_t dim) const
        {
            if (dim == 0) return pts[idx].x;
            else if (dim == 1) return pts[idx].y;
            else return pts[idx].z;
        }

        template <class BBOX>
        bool kdtree_get_bbox(BBOX& /* bb */) const { return false; }
    };

    template <typename T>
    using KDtree = nanoflann::KDTreeSingleIndexAdaptor<nanoflann::L2_Simple_Adaptor<T, PointCloud<T> >, PointCloud<T>, 3>;

    template <int degree>
    void init(const std::vector<VectorAD> &vertices, const std::vector<Eigen::Matrix<int, degree, 1>> &faces) {
        static_assert(degree == 3 || degree == 4);

        Eigen::MatrixXd V(vertices.size(), 3);
        // printf("Passed vertices size = %d\n", vertices.size());
        for ( int i = 0; i < static_cast<int>(vertices.size()); ++i )
            V.row(i) = vertices[i].val;

        Eigen::MatrixXi F(faces.size(), degree);
        for ( int i = 0; i < static_cast<int>(faces.size()); ++i )
            F.row(i) = faces[i];

        Eigen::MatrixXi TF;
        int status = igl::copyleft::tetgen::tetrahedralize(V, F, "", m_TV, m_TT, TF);
        assert(status == 0);
        if ( m_TV.rows() == V.rows() ) {
            assert((m_TV - V).norm() < 1e-5);
            printf("[TETRAHEDRON] No new vertices are added!\n");
            m_vertices = vertices;
        } else {
            printf("[TETRAHEDRON] %ld -> %ld vertices after tetrahedralization...\n", V.rows(), m_TV.rows());
            m_vertices.resize(m_TV.rows(), VectorAD(Vector::Zero()));
            const int nworker = omp_get_num_procs();
            std::vector<int> stats(nworker, 0);

#ifdef USE_KD_TREE
            PointCloud<Float> vCloud;
            vCloud.pts.resize(V.rows());
            for (size_t i = 0; i < V.rows(); i++) {
                vCloud.pts[i].x = V.row(i).x();
                vCloud.pts[i].y = V.row(i).y();
                vCloud.pts[i].z = V.row(i).z();
            }

            KDtree<Float> vIndices(3, vCloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
            vIndices.buildIndex();
#endif
            #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
            for (long int i = 0; i < m_TV.rows(); i++) {
                const int tid = omp_get_thread_num();
                Float tot_invDist2 = 0.0;
                m_vertices[i].val = m_TV.row(i);
                m_vertices[i].zeroGrad();
                bool interpolated = true;
#ifdef USE_KD_TREE
                Float queryPoint[3] = { m_TV.row(i).x(), m_TV.row(i).y(), m_TV.row(i).z()};
                size_t matched_indices[NUM_NEAREST_VERTICES];
                Float dist_sqr[NUM_NEAREST_VERTICES];
                if ( vIndices.knnSearch(queryPoint, NUM_NEAREST_VERTICES, matched_indices, dist_sqr) != NUM_NEAREST_VERTICES ) assert(false);
                // assert( vIndices.knnSearch(queryPoint, NUM_NEAREST_VERTICES, matched_indices, dist_sqr) == NUM_NEAREST_VERTICES );
                if ( dist_sqr[0] < Epsilon ) {
                    m_vertices[i].der = vertices[matched_indices[0]].der;
                    interpolated = false;
                } else {
                    for (int j = 0; j < NUM_NEAREST_VERTICES; j++) {
                        m_vertices[i].der += vertices[matched_indices[j]].der/dist_sqr[j];
                        tot_invDist2 += 1.0/dist_sqr[j];
                    }
                }
#else
                for (long int j = 0; j < V.rows(); j++) {
                    const VectorAD& vAD = vertices[j];
                    Float dist2 = (vAD.val - m_vertices[i].val).squaredNorm();
                    if ( dist2 < Epsilon ) {
                        m_vertices[i].der = vertices[j].der;
                        interpolated = false;
                        break;
                    } else {
                        m_vertices[i].der += vertices[j].der/dist2;
                        tot_invDist2 += 1.0/dist2;
                    }
                }
#endif
                if ( interpolated ) {
                    stats[tid]++;
                    m_vertices[i].der /= tot_invDist2;
                }
            }

            int num_interpolated = 0;
            for (int i = 0; i < nworker; i++)
                num_interpolated += stats[i];
            printf("[TETRAHEDRON] %d vertices velocity calculated...\n", num_interpolated);
        }

        m_tree.init(m_TV, m_TT);
        m_vol.resize(m_TT.rows());
        for ( int i = 0; i < m_TT.rows(); ++i ) {
            const Vector3AD &a = m_vertices[m_TT(i, 0)],
                            &b = m_vertices[m_TT(i, 1)],
                            &c = m_vertices[m_TT(i, 2)],
                            &d = m_vertices[m_TT(i, 3)];
            m_vol[i] = ((a - d).dot((b - d).cross(c - d))).abs()/6.;
        }
    }


    bool query(const Vector3d &p, VectorAD &q, FloatAD &J) const {
        Eigen::VectorXi I;
        igl::in_element(m_TV, m_TT, p.transpose(), m_tree, I);

        int idx = I[0];
        if ( idx < 0 ) return false;

        auto tetra = m_TT.row(idx);
        Vector4d bary;
        bary[0] = tetra_vol(m_TV.row(tetra[1]), m_TV.row(tetra[2]), m_TV.row(tetra[3]), p);
        bary[1] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[2]), m_TV.row(tetra[3]), p);
        bary[2] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]), m_TV.row(tetra[3]), p);
        bary[3] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]), m_TV.row(tetra[2]), p);
        bary /= m_vol[idx].val;
        assert(std::abs(bary.sum() - 1.) < 1e-5);

        q.zero();
        for ( int i = 0; i < 4; ++i )
            q += bary[i]*m_vertices[tetra[i]];

        J = m_vol[idx]/m_vol[idx].val;
        return true;
    }

    bool query_boundary(const VectorAD &p, VectorAD& q) const {
        Eigen::VectorXi I;
        igl::in_element(m_TV, m_TT, p.val.transpose(), m_tree, I);
        int idx = I[0];
        if ( idx < 0 ) return false;

        auto tetra = m_TT.row(idx);
        FloatAD bary[4];
        bary[0] = tetra_vol_boundary(m_vertices[tetra[1]], m_vertices[tetra[2]], m_vertices[tetra[3]], p);
        bary[1] = tetra_vol_boundary(m_vertices[tetra[0]], m_vertices[tetra[2]], m_vertices[tetra[3]], p);
        bary[2] = tetra_vol_boundary(m_vertices[tetra[0]], m_vertices[tetra[1]], m_vertices[tetra[3]], p);
        bary[3] = tetra_vol_boundary(m_vertices[tetra[0]], m_vertices[tetra[1]], m_vertices[tetra[2]], p);
        q.zero();
        for ( int i = 0; i < 4; ++i )
            q += bary[i]/m_vol[idx] * VectorAD(m_vertices[tetra[i]].val);
        assert((q.val-p.val).norm() < Epsilon);
        return true;
    }

    std::vector<VectorAD> m_vertices;
    std::vector<FloatAD> m_vol;

    Eigen::MatrixXd m_TV;
    Eigen::MatrixXi m_TT;
    igl::AABB<Eigen::MatrixXd, 3> m_tree;
};
