#include <tetra.hpp>

#include <cassert>
#include <cstddef>

#include <core/fwd.h>
#include <render/scene.h>
#include "medium/in_element.h"
// namespace
// {
//     void preprocess()
// }

// Inactive version.
__attribute__((optnone)) double tetra_vol(const Vector3d &a, const Vector3d &b,
                                          const Vector3d &c, const Vector3d &d)
{
    return std::abs((a - d).dot((b - d).cross(c - d))) / 6.;
}

// Out-of-line wrapper around `_in_element` (declared elsewhere, presumably in
// "medium/in_element.h"). It locates the element of the tetrahedral mesh
// (TV, TT) that contains `p`, using the AABB `tree` built over that mesh.
//
// Returns an element index into TT. Callers in this file treat a negative
// value as "no containing element".
//
// `optnone` suppresses optimization of this function, and it is registered
// with INACTIVE_FN right below. Presumably this keeps the point-location
// search out of differentiation — confirm against the INACTIVE_FN definition.
//
// NOTE(review): identifiers starting with `__` are reserved for the
// implementation in C++. Renaming would require updating every caller and the
// INACTIVE_FN registration.
__attribute__((optnone)) int __in_element(
    const Eigen::MatrixXd &TV, const Eigen::MatrixXi &TT,
    const Vector &p, const igl::AABB<Eigen::MatrixXd, 3> &tree)
{
    // FIXME(review): the original note here was truncated ("FIXME: the") and
    // its intent is unknown. `p` is passed as a row (transposed), presumably
    // because `_in_element` expects one query point per row — confirm.
    return _in_element(TV, TT, p.transpose(), tree);
}
// Register the optnone helpers above with INACTIVE_FN (macro defined elsewhere).
// Presumably this marks them as non-differentiable for the autodiff tool, so no
// derivatives propagate through point location or through the `tetra_vol`
// volumes used for barycentric weights — confirm against the macro definition.
INACTIVE_FN(__in_element, __in_element);
INACTIVE_FN(__tetra_vol, tetra_vol);

// Index of the element of (m_TV, m_TT) that contains `p`, as reported by the
// `__in_element` helper using the member AABB tree `m_tree`.
// A negative value means "no containing element"; that is how query/queryAD
// interpret the same helper's result.
int TetrahedronMesh::in_element(const Vector3d &p) const
{
    const int containing_element = __in_element(m_TV, m_TT, p, m_tree);
    return containing_element;
}
// Barycentric coordinates of `p` with respect to tetrahedron `tet_id`.
//
// Coordinate k is the volume (via `tetra_vol`, on the positions in `m_TV`) of
// the sub-tetrahedron formed by `p` and the three corners other than k. That
// volume is then divided by `getVol(tet_id)`.
// `tetra_vol` returns unsigned volumes, so the coordinates sum to 1 only when
// `p` lies inside or on the element. The assert checks this in debug builds.
Vector4 TetrahedronMesh::getBarycentric(int tet_id, const Vector3d &p) const
{
    // For each local corner k, the other three local corners in ascending order.
    static constexpr int kOpposite[4][3] = {
        {1, 2, 3}, {0, 2, 3}, {0, 1, 3}, {0, 1, 2}};

    const auto corners = m_TT.row(tet_id);
    Vector4d weights;
    for (int k = 0; k < 4; ++k)
    {
        const int *face = kOpposite[k];
        weights[k] = tetra_vol(m_TV.row(corners[face[0]]),
                               m_TV.row(corners[face[1]]),
                               m_TV.row(corners[face[2]]), p);
    }
    weights /= getVol(tet_id);
    assert(std::abs(weights.sum() - 1.) < 1e-5);
    return weights;
}

// Volume of tetrahedron `i`.
//
// With PREPROCESS defined, the cached `m_vol[i]` is returned and `scene` is not
// consulted. Otherwise the volume is recomputed with `tetra_volAD` from the four
// corner positions returned by `getVertex(scene, ...)`. Presumably this lets
// derivatives w.r.t. the scene's vertices flow through it — confirm.
Float TetrahedronMesh::getVol(const Scene &scene, int i) const
{
#ifdef PREPROCESS
    (void)scene; // not needed when volumes are cached
    return m_vol[i];
#else
    auto corner_ids = m_TT.row(i);
    Vector p0 = getVertex(scene, corner_ids[0]);
    Vector p1 = getVertex(scene, corner_ids[1]);
    Vector p2 = getVertex(scene, corner_ids[2]);
    Vector p3 = getVertex(scene, corner_ids[3]);
    return tetra_volAD(p0, p1, p2, p3);
#endif
}

// Locate the element containing `p` and express `p` as the same barycentric
// combination of that element's corners in `m_vertices`.
//
// The weights are sub-tetrahedron volumes computed with `tetra_vol` on the
// positions in `m_TV`. They are normalised by a detached element volume so
// that no derivative flows through them.
//
//   p : query point.
//   q : output, the interpolated point; written only on success.
//   J : output, V / detach(V) with V the element volume. This is numerically 1
//       for a non-degenerate element and is presumably kept for its
//       derivative (confirm).
// Returns false, leaving q and J untouched, when no element contains `p`.
bool TetrahedronMesh::query(const Vector3d &p, Vector &q, Float &J) const
{
    const int idx = __in_element(m_TV, m_TT, p.transpose(), m_tree);
    if (idx < 0)
        return false;

    auto tetra = m_TT.row(idx);
    Vector4d bary;
    bary[0] = tetra_vol(m_TV.row(tetra[1]), m_TV.row(tetra[2]),
                        m_TV.row(tetra[3]), p);
    bary[1] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[2]),
                        m_TV.row(tetra[3]), p);
    bary[2] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]),
                        m_TV.row(tetra[3]), p);
    bary[3] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]),
                        m_TV.row(tetra[2]), p);

    // Evaluate the element volume once: it is used detached for the weights
    // (so bary carries no derivative) and attached for J below.
    Float V = getVol(idx);
    bary /= detach(V);
    assert(std::abs(bary.sum() - 1.) < 1e-5);

    q.setZero();
    for (int i = 0; i < 4; ++i)
        q += bary[i] * m_vertices[tetra[i]];

    J = V / detach(V);
    return true;
}

// Scene-aware variant of `query`: locate the element containing `p` and
// interpolate the scene's current vertex positions with `p`'s barycentric
// weights.
//
// The weights come from `tetra_vol` on the positions in `m_TV` and are
// normalised by a detached volume, so they carry no derivative. `q` is the
// weighted sum of `getVertex(scene, ...)` over the element's corners, which is
// where gradients w.r.t. those vertices enter.
//
//   scene : source of vertex positions (see getVertex / getVol).
//   p     : query point.
//   q     : output, the interpolated point; written only on success.
//   J     : output, vol / detach(vol). This is numerically 1 for a
//           non-degenerate element and is presumably kept for its
//           derivative (confirm).
// Returns false, leaving q and J untouched, when no element contains `p`.
bool TetrahedronMesh::queryAD(const Scene &scene, const Vector3d &p, Vector &q, Float &J) const
{
    const int idx = __in_element(m_TV, m_TT, p.transpose(), m_tree);
    if (idx < 0)
        return false;

    auto tetra = m_TT.row(idx);
    Vector4d bary;
    bary[0] = tetra_vol(m_TV.row(tetra[1]), m_TV.row(tetra[2]),
                        m_TV.row(tetra[3]), p);
    bary[1] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[2]),
                        m_TV.row(tetra[3]), p);
    bary[2] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]),
                        m_TV.row(tetra[3]), p);
    bary[3] = tetra_vol(m_TV.row(tetra[0]), m_TV.row(tetra[1]),
                        m_TV.row(tetra[2]), p);

    // Normalise by a detached volume so bary carries no derivative; the
    // attached `vol` is reused for J below.
    // NOTE(review): the numerators use `m_TV` while `vol` comes from
    // getVol(scene, idx). The assert only holds if the two agree numerically
    // (scene positions equal to m_TV, or the cached m_vol) — confirm.
    Float vol = getVol(scene, idx);
    bary /= detach(vol);
    assert(std::abs(bary.sum() - 1.) < 1e-5);

    // Interpolate the actual scene vertices (not m_TV) so the gradient
    // propagates to those vertices.
    q.setZero();
    for (int i = 0; i < 4; ++i)
    {
        const int vId = tetra[i];
        Vector v = getVertex(scene, vId);
        q += bary[i] * v;
    }

    J = vol / detach(vol);
    return true;
}

// Current position of mesh vertex `i`.
//
// A vertex with an entry in `m_ids` is mapped to vertex `m_ids[i].second` of
// scene shape `m_ids[i].first`, and its position is read from that shape.
// A vertex without a usable mapping falls back to the stored position
// `m_TV.row(i)`. "No usable mapping" means any of:
//   - no `m_ids` entry,
//   - a null shape,
//   - a shape-vertex index that is negative or out of range.
Vector TetrahedronMesh::getVertex(const Scene &scene, int i) const
{
    // A negative index would be an invalid row for m_TV as well as for m_ids.
    assert(i >= 0);
    if (static_cast<std::size_t>(i) >= static_cast<std::size_t>(m_ids.size()))
        return m_TV.row(i);

    const int shape_id = m_ids[i].first;
    const int v_id = m_ids[i].second;
    Shape *shape = scene.getShape(shape_id);

    // Reject negative ids explicitly. Previously they were rejected only by the
    // implicit signed->unsigned conversion in the comparison, which would not
    // happen if `vertices.size()` returned a signed type.
    // NOTE(review): if `vertices` is a matrix, confirm `size()` is the vertex
    // count rather than the coefficient count.
    const bool has_shape_vertex =
        shape != nullptr && v_id >= 0 &&
        static_cast<std::size_t>(v_id) <
            static_cast<std::size_t>(shape->vertices.size());
    if (!has_shape_vertex)
        return m_TV.row(i);
    return shape->getVertex(v_id);
}