#include <core/fwd.h>
#include <core/logger.h>
#include <core/math_func.h>
#include <core/ray.h>
#include <core/timer.h>
#include <core/transform.h>
#include <render/intersection.h>
#include <render/shape.h>
#include <tiny_obj_loader.h>

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <unordered_set>

namespace {
[[maybe_unused]] void print(std::unordered_set<int> const &s) {
    // Debug helper: writes every element to stdout followed by one space,
    // in the set's (unspecified) iteration order, without a trailing newline.
    for (const int value : s)
        std::cout << value << " ";
}

std::vector<Edge> constructEdges(const std::vector<Vector>   &vertices,
                                 const std::vector<Vector3i> &indices,
                                 const std::vector<Vector>   &faceNormals) {
    /* pair of vertex indices, vector (face i, third vertex i) */
    std::map<std::pair<int, int>, std::vector<Vector2i>> edge_map;
    for (int itri = 0; itri < static_cast<int>(indices.size()); itri++) {
        Vector3i ind = indices[itri];
        for (int iedge = 0; iedge < 3; iedge++) {
            int k1                  = iedge,
                k2                  = (iedge + 1) % 3;
            std::pair<int, int> key = (ind[k1] < ind[k2])
                                          ? std::make_pair(ind[k1], ind[k2])
                                          : std::make_pair(ind[k2], ind[k1]);
            if (edge_map.find(key) == edge_map.end())
                edge_map[key] = std::vector<Vector2i>();
            edge_map[key].push_back(Vector2i(itri,
                                             ind[(iedge + 2) % 3]));
        }
    }

    std::vector<Edge> edges;
    for (const auto &it : edge_map) {
        Float length = (vertices[it.first.first] -
                        vertices[it.first.second])
                           .norm();

        // check if good mesh
        if (it.second.size() > 2) {
            std::cerr << "Every edge can be shared by at most 2 faces!"
                      << std::endl;
            assert(false);
        } else if (it.second.size() == 2) // an edge is shared by two faces
        {
            // face id
            const int ind0 = it.second[0][0],
                      ind1 = it.second[1][0];
            if (ind0 == ind1) {
                std::cerr << "Duplicated faces!" << std::endl;
                assert(false);
            }

            const Vector &n0 = faceNormals[ind0],
                         &n1 = faceNormals[ind1];
            Float val        = n0.dot(n1);
            if (val < -1.0f + EdgeEpsilon) {
                // std::cerr << "Inconsistent normal orientation! n0.dot(n1) = "
                //           << val << std::endl;
                // assert(false);
            } else if (val < 1.0f - EdgeEpsilon) // not coplanar
            {
                //
                [[maybe_unused]] Float tmp0 = n0.dot(vertices[it.second[1][1]] -
                                                     vertices[it.first.first]),
                                       tmp1 = n1.dot(vertices[it.second[0][1]] -
                                                     vertices[it.first.first]);
                assert(math::signum(tmp0) * math::signum(tmp1) > 0.5f);
                edges.push_back(Edge(it.first.first, it.first.second,
                                     ind0, ind1, length, it.second[0][1],
                                     tmp0 > Epsilon ? -1 : 1));
            }
        } else {
            assert(it.second.size() == 1);
            edges.push_back(Edge(it.first.first, it.first.second,
                                 it.second[0][0], -1, length, it.second[0][1], 0));
        }
    }
    return edges;
}

std::vector<Edge> constructSortEdges(const std::vector<Vector>   &vertices,
                                     const std::vector<Vector3i> &indices,
                                     const std::vector<Vector> &faceNormals, const Sort_config &sort_config) {
    std::map<int, std::unordered_set<int>> vertices_neighbor;
    [[maybe_unused]] int                   vertex_size = vertices.size();
    [[maybe_unused]] int                   indice_size = indices.size();

    /* pair of vertex indices, vector (face i, third vertex i) */
    std::map<std::pair<int, int>, std::vector<Vector2i>> edge_map;
    for (int itri = 0; itri < static_cast<int>(indices.size()); itri++) {
        Vector3i ind = indices[itri];
        for (int iedge = 0; iedge < 3; iedge++) {
            int k1                  = iedge,
                k2                  = (iedge + 1) % 3;
            std::pair<int, int> key = (ind[k1] < ind[k2])
                                          ? std::make_pair(ind[k1], ind[k2])
                                          : std::make_pair(ind[k2], ind[k1]);
            if (edge_map.find(key) == edge_map.end())
                edge_map[key] = std::vector<Vector2i>();
            edge_map[key].push_back(Vector2i(itri,
                                             ind[(iedge + 2) % 3]));
        }
    }
    std::vector<Edge> edges;
    for (const auto &it : edge_map) {
        Float length = (vertices[it.first.first] -
                        vertices[it.first.second])
                           .norm();
        // check if good mesh
        if (it.second.size() > 2) {
            std::cerr << "Every edge can be shared by at most 2 faces!"
                      << std::endl;
            assert(false);
        } else if (it.second.size() == 2) // an edge is shared by two faces
        {
            // face id
            const int ind0 = it.second[0][0],
                      ind1 = it.second[1][0];
            if (ind0 == ind1) {
                std::cerr << "Duplicated faces!" << std::endl;
                assert(false);
            }

            const Vector &n0 = faceNormals[ind0],
                         &n1 = faceNormals[ind1];
            Float val        = n0.dot(n1);
            if (val < -1.0f + EdgeEpsilon) {
                std::cerr << "Inconsistent normal orientation! n0.dot(n1) = "
                          << val << std::endl;
            } else if (val < 1.0f - EdgeEpsilon) // not coplanar
            {
                [[maybe_unused]] Float tmp0 = n0.dot(vertices[it.second[1][1]] -
                                                     vertices[it.first.first]),
                                       tmp1 = n1.dot(vertices[it.second[0][1]] -
                                                     vertices[it.first.first]);
                assert(math::signum(tmp0) * math::signum(tmp1) > 0.5f);
                edges.push_back(Edge(it.first.first, it.first.second,
                                     ind0, ind1, length, it.second[0][1],
                                     tmp0 > Epsilon ? -1 : 1));
            }
        } else {
            assert(it.second.size() == 1);
            edges.push_back(Edge(it.first.first, it.first.second,
                                 it.second[0][0], -1, length, it.second[0][1], 0));
        }
    }

    if (edges[0].length < 0.00001) {
        return edges;
    }
    auto rd  = std::random_device{};
    auto rng = std::default_random_engine{ rd() };
    rng.seed(10);
    std::shuffle(std::begin(edges), std::end(edges), rng);

    if (sort_config.max_length == -1) {
        // std::cout << "using unsorted edge" << std::endl;
        return edges;
    }

    for (int i = 0; i < static_cast<int>(edges.size()); ++i) {
        if (vertices_neighbor.find(edges[i].v0) == vertices_neighbor.end())
            vertices_neighbor[edges[i].v0] = std::unordered_set<int>();
        vertices_neighbor[edges[i].v0].insert(edges[i].v1);

        if (vertices_neighbor.find(edges[i].v1) == vertices_neighbor.end())
            vertices_neighbor[edges[i].v1] = std::unordered_set<int>();
        vertices_neighbor[edges[i].v1].insert(edges[i].v0);
    }

    std::map<std::pair<int, int>, bool> edge_walk_map;
    for (int i = 0; i < static_cast<int>(edges.size()); ++i) {
        std::pair<int, int> walk_key = std::make_pair(edges[i].v0, edges[i].v1);
        edge_walk_map[walk_key]      = false;
    }

    // std::vector<int> jump;

    std::vector<Vector2i> draw;
    for (int i = 0; i < static_cast<int>(vertices_neighbor.size()); ++i) {
        // Greedy Search Here
        for (int neig : vertices_neighbor[i]) {
            int    prev    = i;
            int    curr    = neig;
            Vector P_start = vertices[curr] - vertices[prev];

            std::pair<int, int> walk_key = (prev < curr)
                                               ? std::make_pair(prev, curr)
                                               : std::make_pair(curr, prev);
            if (edge_walk_map[walk_key] == false) {
                // Greedy Search
                for (int depth = 0; depth < sort_config.max_length; ++depth) {
                    std::pair<int, int> walk_key = (prev < curr)
                                                       ? std::make_pair(prev, curr)
                                                       : std::make_pair(curr, prev);
                    if (edge_walk_map[walk_key] == true) {
                        break;
                    }
                    edge_walk_map[walk_key] = true;
                    draw.push_back(Vector2i(prev, curr));
                    int   best_id  = -1;
                    Float best_cos = 1.1;

                    for (int curr_id : vertices_neighbor[curr]) {
                        std::pair<int, int> walk_best_key = (curr < curr_id)
                                                                ? std::make_pair(curr, curr_id)
                                                                : std::make_pair(curr_id, curr);
                        if (edge_walk_map[walk_best_key] == false) {
                            Vector P0 = vertices[curr] - vertices[prev];
                            Vector P1 = vertices[curr] - vertices[curr_id];

                            Float                  cos_val       = P0.dot(P1) / (P0.norm() * P1.norm());
                            [[maybe_unused]] Float start_cos_val = P_start.dot(P1) / (P_start.norm() * P1.norm());
                            if (cos_val < best_cos) {
                                best_id  = curr_id;
                                best_cos = cos_val;
                            }
                        }
                    }
                    Vector P_end      = vertices[curr] - vertices[best_id];
                    Float  global_cos = P_start.dot(P_end) / (P_start.norm() * P_end.norm());
                    // std::cout << prev << " -> " << curr << " :-> " << best_id << std::endl;

                    if (best_cos > sort_config.cos_val_local || global_cos > sort_config.cos_val_global || best_id == -1) {
                        // std::cout << "break" << std::endl;
                        break;
                    }

                    // std::cout << prev << " -> " << curr << " :-> " << best_id << std::endl;

                    prev = curr;
                    curr = best_id;
                }
                // jump.push_back(draw.size()-1);
            }
        }
    }
    // std::cout << jump.size() << std::endl;
    // std::cout << draw.size() << std::endl;

    // for (int j : jump) {
    //     std::cout << j << " ";
    // }
    // std::cout << std::endl;

    std::map<std::pair<int, int>, int> sort_edge_map;
    for (int i = 0; i < static_cast<int>(edges.size()); ++i) {
        const auto         &e = edges[i];
        std::pair<int, int> e_pair(e.v0, e.v1);
        sort_edge_map[e_pair] = i;
    }

    std::vector<Edge> sort_edges;

    for (Vector2i draw_edge : draw) {
        std::pair<int, int> key = (draw_edge[0] < draw_edge[1])
                                      ? std::make_pair(draw_edge[0], draw_edge[1])
                                      : std::make_pair(draw_edge[1], draw_edge[0]);

        int  id_edge = sort_edge_map[key];
        Edge data    = edges[id_edge];
        data.v0      = draw_edge[0];
        data.v1      = draw_edge[1];
        sort_edges.push_back(data);
    }

    assert(sort_edges.size() == edges.size());
    std::cout << "using sorted edge" << std::endl;
    return sort_edges;
}
} // namespace

// Build a shape from scene-description properties, then configure() it.
// Keys read without a default: "vertices", "indices", "bsdf_id".
// With defaults: "uvs", "uv_indices", "normals" (empty); "light_id",
// "med_int_id", "med_ext_id" (-1); "to_world" (identity); "use_face_normals" (true).
Shape::Shape(const Properties &props) {
    vertices_raw       = props.get<std::vector<Vector>>("vertices");
    indices            = props.get<std::vector<Vector3i>>("indices");
    uvs                = props.get<std::vector<Vector2>>("uvs", std::vector<Vector2>());
    uv_indices         = props.get<std::vector<Vector3i>>("uv_indices", std::vector<Vector3i>());
    normals_raw        = props.get<std::vector<Vector>>("normals", std::vector<Vector>());
    // Positions were just loaded into vertices_raw; `vertices` is still empty here.
    num_vertices       = vertices_raw.size();
    num_triangles      = indices.size();
    light_id           = props.get<int>("light_id", -1);
    bsdf_id            = props.get<int>("bsdf_id");
    med_int_id         = props.get<int>("med_int_id", -1);
    med_ext_id         = props.get<int>("med_ext_id", -1);
    m_to_world         = props.get<Matrix4x4>("to_world", Matrix4x4::Identity());
    m_use_face_normals = props.get<bool>("use_face_normals", true);
    configure();
}

// Build a shape directly from mesh arrays, then configure() it.
// `vertices` become the object-space positions, which configure() maps through
// m_to_world. This constructor does not set m_to_world or m_use_face_normals,
// so their member defaults apply.
// NOTE(review): presumably m_to_world defaults to identity; confirm in the header.
// `num_vertices` / `num_triangles` must match the array sizes. Per-vertex
// `normals` are only accepted when m_use_face_normals is false.
// NOTE(review): uv_indices is not set by this constructor.
Shape::Shape(const std::vector<Vector> &vertices, const std::vector<Vector3i> &indices,
             const std::vector<Vector2> &uvs, const std::vector<Vector> &normals,
             int num_vertices, int num_triangles, int light_id, int bsdf_id,
             int med_int_id, int med_ext_id)
    : vertices(vertices), indices(indices), uvs(uvs),
      num_vertices(num_vertices), num_triangles(num_triangles),
      light_id(light_id), bsdf_id(bsdf_id), med_int_id(med_int_id),
      med_ext_id(med_ext_id) {
    assert(static_cast<int>(vertices.size()) == num_vertices);
    assert(static_cast<int>(indices.size()) == num_triangles);

    // configure() takes num_vertices from vertices_raw and rebuilds `vertices`
    // from it. Without this line the mesh would be resized to zero vertices.
    vertices_raw = vertices;

    if (m_use_face_normals)
        assert(normals.size() == 0);
    else {
        if (normals.size() > 0) {
            assert(light_id < 0);
            assert(static_cast<int>(normals.size()) == num_vertices);
            this->normals.resize(num_vertices);
            for (int i = 0; i < num_vertices; i++)
                this->normals[i] = normals[i].normalized();
        }
    }

    configure();
}

// Same as the array-based constructor, but also stores the edge-sorting
// configuration used by configure() when edge chaining is enabled.
// Normals are rejected entirely when SHAPE_COMPUTE_VTX_NORMAL is defined.
// NOTE(review): m_to_world is not set here, so its member default applies; confirm it is identity.
Shape::Shape(const std::vector<Vector> &vertices, const std::vector<Vector3i> &indices,
             const std::vector<Vector2> &uvs, const std::vector<Vector> &normals,
             int num_vertices, int num_triangles, int light_id, int bsdf_id,
             int med_int_id, int med_ext_id, const Sort_config &config)
    : vertices(vertices), indices(indices), uvs(uvs),
      num_vertices(num_vertices), num_triangles(num_triangles),
      light_id(light_id), bsdf_id(bsdf_id), med_int_id(med_int_id),
      med_ext_id(med_ext_id) {
    assert(static_cast<int>(vertices.size()) == num_vertices);
    assert(static_cast<int>(indices.size()) == num_triangles);

    // configure() takes num_vertices from vertices_raw and rebuilds `vertices`
    // from it. Without this line the mesh would be resized to zero vertices.
    vertices_raw = vertices;

#ifdef SHAPE_COMPUTE_VTX_NORMAL
    assert(normals.size() == 0);
#else
    if (normals.size() > 0) {
        assert(light_id < 0);
        assert(static_cast<int>(normals.size()) == num_vertices);
        this->normals.resize(num_vertices);
        for (int i = 0; i < num_vertices; i++)
            this->normals[i] = normals[i].normalized();
    }
#endif
    sort_config = config;
    configure();
}

// load vertices, indices, normals, uvs from obj file and then configure
// Load object-space vertices, triangle indices and uvs (with uv indices) from
// an OBJ file, replacing any existing mesh data. Faces from all OBJ shapes are
// concatenated into one triangle list, and every face must be a triangle.
// Normals are cleared but not read from the file.
//
// @param filename    path to the .obj file
// @param auto_config when true, configure() is called after loading
void Shape::load(const std::string &filename, bool auto_config) {
    // Reset all per-mesh buffers up front so nothing from a previous mesh survives.
    vertices_raw.clear();
    vertices.clear();
    indices.clear();
    uv_indices.clear();
    normals.clear();
    uvs.clear();

    tinyobj::attrib_t                attrib;
    std::vector<tinyobj::shape_t>    shapes;
    std::vector<tinyobj::material_t> materials;
    std::string                      warn, err;
    PSDR_ASSERT_MSG(tinyobj::LoadObj(&attrib, &shapes, &materials,
                                     &warn, &err, filename.c_str()),
                    fmt::runtime("Failed to load obj file: " + filename + "\n" + err));
    assert(attrib.vertices.size() % 3 == 0);

    // Vertex positions, stored as flat (x, y, z) triples.
    vertices_raw.resize(attrib.vertices.size() / 3);
    for (int i = 0; i < static_cast<int>(vertices_raw.size()); i++)
        vertices_raw[i] = Vector(attrib.vertices[i * 3],
                                 attrib.vertices[i * 3 + 1],
                                 attrib.vertices[i * 3 + 2]);

    // Texture coordinates, stored as flat (u, v) pairs, if present.
    if (!attrib.texcoords.empty()) {
        uvs.resize(attrib.texcoords.size() / 2);
        for (int i = 0; i < static_cast<int>(uvs.size()); i++)
            uvs[i] = Vector2(attrib.texcoords[i * 2],
                             attrib.texcoords[i * 2 + 1]);
    }

    // Faces.
    for (const auto &shape : shapes) {
        const auto &idx       = shape.mesh.indices;
        const int   num_faces = static_cast<int>(shape.mesh.num_face_vertices.size());
        for (int f = 0; f < num_faces; f++) {
            const int fv = shape.mesh.num_face_vertices[f];
            PSDR_ASSERT_MSG(fv == 3, "Only triangular faces are supported!");
            const auto &c0 = idx[f * 3];
            const auto &c1 = idx[f * 3 + 1];
            const auto &c2 = idx[f * 3 + 2];
            indices.push_back(Vector3i(c0.vertex_index,
                                       c1.vertex_index,
                                       c2.vertex_index));
            if (!uvs.empty()) {
                // All three corners need a texcoord. Otherwise a -1 index would be
                // stored in uv_indices (and presumably used to index `uvs` later).
                PSDR_ASSERT_MSG(c0.texcoord_index >= 0 && c1.texcoord_index >= 0 &&
                                    c2.texcoord_index >= 0,
                                "Texture coordinates are not provided!");
                uv_indices.push_back(Vector3i(c0.texcoord_index,
                                              c1.texcoord_index,
                                              c2.texcoord_index));
            }
        }
    }
    if (auto_config)
        configure();
}

void Shape::configureC() {
    // Differentiable part of configuration: write each object-space vertex of
    // vertices_raw, mapped through m_to_world, into `vertices`. `vertices` must
    // already have the same length. This is the function differentiated by
    // configureD.
    assert(vertices.size() == vertices_raw.size());
    const int count = static_cast<int>(vertices_raw.size());
    for (int vid = 0; vid < count; ++vid) {
        vertices[vid] = psdr::transform_pos(m_to_world, vertices_raw[vid]);
    }

    // NOTE: face/vertex normals are intentionally not recomputed here (disabled
    // earlier, marked FIXME). configure() computes them after calling this.
}

void Shape::configure() {
    // Rebuild every piece of derived mesh state from vertices_raw / indices.

    // Sizes.
    num_triangles = indices.size();
    num_vertices  = vertices_raw.size();

    if (m_use_face_normals)
        assert(normals.size() == 0);

    // World-space positions.
    vertices.resize(num_vertices);
    configureC();

    // Vertex -> face adjacency and the per-triangle area distribution.
    computeAdjacentFaces();
    computeFaceDistribution();

    // Normals: per face always; per vertex only when face normals are not used.
    const bool per_vertex_normals = !m_use_face_normals;
    faceNormals.clear();
    faceNormals.resize(num_triangles);
    if (per_vertex_normals)
        normals.resize(num_vertices);
    computeFaceNormals();
    if (per_vertex_normals)
        computeVertexNormals();

    // Every vertex starts with zero velocity.
    velocities.resize(num_vertices);
    for (int vid = 0; vid < num_vertices; ++vid)
        velocities[vid] = Vector(0, 0, 0);

    // Edges (chained ordering when enable_draw is set), their adjacency, and
    // a length-weighted sampling distribution over them.
    edges = enable_draw ? constructSortEdges(vertices, indices, faceNormals, sort_config)
                        : constructEdges(vertices, indices, faceNormals);
    computeAdjacentEdges();
    edge_distrb.clear();
    for (const Edge &e : edges)
        edge_distrb.append(e.length);
    edge_distrb.normalize();
}

namespace {
void __configureC(Shape &shape) {
    shape.configureC();
}

void configureD(const Shape &shape, Shape &d_shape) {
    __enzyme_autodiff((void *) __configureC, //
                      enzyme_dup, &shape, &d_shape);
}
} // namespace

void Shape::configureD(const Shape &primal) {
    ::configureD(primal, *this);
}

void Shape::computeAdjacentFaces() {
    // For every vertex, collect the indices of the triangles that reference it.
    // .at() throws std::out_of_range if a face references a vertex >= num_vertices.
    adjacentFaces.clear();
    adjacentFaces.resize(num_vertices);
    const int face_count = static_cast<int>(indices.size());
    for (int fid = 0; fid < face_count; ++fid) {
        const auto &tri = indices.at(fid);
        for (int corner = 0; corner < 3; ++corner)
            adjacentFaces.at(tri[corner]).push_back(fid);
    }
}

void Shape::computeAdjacentEdges() {
    // For every vertex, collect the indices of the edges incident to it.
    // .at() throws std::out_of_range if an edge references a vertex >= num_vertices.
    adjacentEdges.clear();
    adjacentEdges.resize(num_vertices);
    // Same int-typed bound as computeAdjacentFaces: avoids the signed/unsigned
    // comparison of `int < size()`.
    const int edge_count = static_cast<int>(edges.size());
    for (int i = 0; i < edge_count; i++) {
        const auto &edge = edges[i];
        adjacentEdges.at(edge.v0).push_back(i);
        adjacentEdges.at(edge.v1).push_back(i);
    }
}

void Shape::computeFaceDistribution() {
    face_distrb.clear();
    face_distrb.reserve(num_triangles);
    for (int i = 0; i < num_triangles; ++i) {
        const auto &ind  = getIndices(i);
        const auto &v0   = getVertex(ind(0));
        const auto &v1   = getVertex(ind(1));
        const auto &v2   = getVertex(ind(2));
        auto        cur  = (v1 - v0).cross(v2 - v0);
        Float       area = cur.norm();
        face_distrb.append(0.5f * area);
    }
    face_distrb.normalize();
}
INACTIVE_FN(Shape_computeFaceDistribution, &Shape::computeFaceDistribution);

void Shape::computeFaceNormals() {
    // Unit normal of each triangle, oriented by its winding:
    // normalize((v1 - v0) x (v2 - v0)). faceNormals must already hold
    // num_triangles entries (configure() sizes it).
    assert(static_cast<int>(faceNormals.size()) == num_triangles);
    for (int tri = 0; tri < num_triangles; ++tri) {
        const auto &corners = getIndices(tri);
        const auto &p0      = getVertex(corners(0));
        const auto &p1      = getVertex(corners(1));
        const auto &p2      = getVertex(corners(2));
        // Zero-area triangles are not rejected here.
        faceNormals[tri] = (p1 - p0).cross(p2 - p0).normalized();
    }
}

/**
 * Angle in radians, in [0, pi], between two unit vectors `u` and `v`.
 *
 * Uses the half-chord identities, which are better conditioned than
 * acos(u.dot(v)) near 0 and pi:
 *   |v - u| = 2 sin(theta / 2)  =>  theta = 2 asin(|v - u| / 2)
 *   |v + u| = 2 cos(theta / 2)  =>  theta = pi - 2 asin(|v + u| / 2)
 * The second form is used for obtuse angles. Both require the factor 2.
 * The asin argument is clamped to 1 so rounding cannot produce NaN.
 * Inputs are assumed to be normalized.
 */
Float unitAngle(const Vector &u, const Vector &v) {
    if (u.dot(v) < 0.0) {
        Float half_chord = 0.5 * (v + u).norm();
        if (half_chord > 1)
            half_chord = 1;
        return M_PI - 2.0 * asin(half_chord);
    }
    Float half_chord = 0.5 * (v - u).norm();
    if (half_chord > 1)
        half_chord = 1;
    return 2.0 * asin(half_chord);
}

void Shape::computeVertexNormals() {
    // Vertex normal = normalized, unweighted sum of the adjacent face normals.
    // Requires adjacentFaces and the face normals to be up to date.
    assert(static_cast<int>(normals.size()) == num_vertices);
    for (int vid = 0; vid < num_vertices; ++vid) {
        auto &acc = normals[vid];
        acc       = Vector::Zero();
        for (const int face : adjacentFaces[vid])
            acc += getFaceNormal(face);
        acc.normalize();
    }
}

Float Shape::getArea(int index) const {
    // Area of triangle `index`: half the length of the cross product of two edges.
    const auto &corners   = getIndices(index);
    const auto &a         = getVertex(corners(0));
    const auto &b         = getVertex(corners(1));
    const auto &c         = getVertex(corners(2));
    const Float twice_area = (b - a).cross(c - a).norm();
    return 0.5f * twice_area;
}

// void Shape::samplePosition(int index, const Vector2 &rnd2, Vector &pos,
// Vector &norm) const {
//     const Vector3i& ind = getIndices(index);
//     const Vector& v0 = getVertex(ind(0));
//     const Vector& v1 = getVertex(ind(1));
//     const Vector& v2 = getVertex(ind(2));
//     Float a = std::sqrt(rnd2[0]);
//     pos = v0 + (v1 - v0)*(1.0f - a) + (v2 - v0)*(a*rnd2[1]);
//     norm = faceNormals[index].val;
// }

int Shape::samplePosition(const Vector2          &_rnd2,
                          PositionSamplingRecord &pRec,
                          const DiscreteDistribution *distrib, EmitterPrimarySampleRecord *dPSRec) const {
    Vector2 rnd2(_rnd2);
    Float   pdf;
    int     tri_idx    = distrib == nullptr ? face_distrb.sampleReuse(rnd2[0], pdf)
                                            : distrib->sampleReuse(rnd2[0], pdf);
    pRec.tri_id        = tri_idx;
    Float   a          = std::sqrt(rnd2[0]);
    Vector2 uv         = Vector2{ 1.0f - a, a * rnd2[1] };
    Float   u          = detach(uv.x());
    Float   v          = detach(uv.y());
    pRec.barycentric   = Vector2{ u, v };
    const Vector3i ind = getIndices(tri_idx);
    const Vector  &v0  = getVertex(ind(0));
    const Vector  &v1  = getVertex(ind(1));
    const Vector  &v2  = getVertex(ind(2));
    pRec.p             = v0 + (v1 - v0) * u + (v2 - v0) * v;
    pRec.n             = getFaceNormal(tri_idx);
    pRec.uv            = Vector2(u, v);
    Float area         = getArea();
    pRec.pdf           = detach(pdf) * detach(1. / getArea(tri_idx));
    pRec.measure       = EMArea;
    area               = getArea(tri_idx);
    pRec.J             = area / detach(area);

    if (dPSRec) {
        dPSRec->offset = distrib == nullptr ? face_distrb.m_cdf[tri_idx]
                                            : distrib->m_cdf[tri_idx];
        dPSRec->scale = distrib == nullptr ? face_distrb.m_cdf[tri_idx + 1] - face_distrb.m_cdf[tri_idx]
                                           : distrib->m_cdf[tri_idx + 1] - distrib->m_cdf[tri_idx];
        dPSRec->v0 = v0;
        dPSRec->v1 = v1;
        dPSRec->v2 = v2;
        dPSRec->emitter_id = light_id;
        dPSRec->triangle_id = tri_idx;
        dPSRec->n = pRec.n;
        dPSRec->continuous = false;
    }
    return tri_idx;
}

// int Shape::samplePositionCont(const Vector2      &_rnd2,
//                           PositionSamplingRecord &pRec,
//                           const DiscreteDistribution *distrib, EmitterPrimarySampleRecord *dPSRec) const { // sample position on UVmap, find corresponding position on mesh
//     Vector2 rnd2(_rnd2);
//     Float   pdf = 1.0;
//     Point_2 query_point(rnd2[0], rnd2[1]);
//     Vector2 uv;
//     int     tri_idx = -1;

//     for (int i = 0; i < indices.size(); i++)
//     {
//         Ray query_ray(Vector(rnd2[0], rnd2[1], 1.0), Vector(0, 0, -1));
//         std::vector<Scalar> barycentrics;
//         Vector uv_x = Vector(uv_map[vertex_descriptor(indices[i][0])][0], uv_map[vertex_descriptor(indices[i][0])][1], 0.0);
//         Vector uv_y = Vector(uv_map[vertex_descriptor(indices[i][1])][0], uv_map[vertex_descriptor(indices[i][1])][1], 0.0);
//         Vector uv_z = Vector(uv_map[vertex_descriptor(indices[i][2])][0], uv_map[vertex_descriptor(indices[i][2])][1], 0.0);
        
//         Array uvt = rayIntersectTriangle(uv_x, uv_y, uv_z, query_ray);
//         if (uvt[0] > -1e-6 && uvt[1] > -1e-6 && uvt[0] + uvt[1] < 1.0+1e-6)
//         {
//             uv = Vector2(uvt[0], uvt[1]);
//             tri_idx = i;
//             break;
//         }
//     }

//     // auto result = PMP::locate(K::Point_3(rnd2[0], rnd2[1], 0.0), uv_mesh);
//     pRec.tri_id        = tri_idx;
//     Float   u          = detach(uv.x());
//     Float   v          = detach(uv.y());
//     pRec.barycentric   = Vector2{ u, v };
//     const Vector3i ind = getIndices(tri_idx);
//     const Vector  &v0  = getVertex(ind(0));
//     const Vector  &v1  = getVertex(ind(1));
//     const Vector  &v2  = getVertex(ind(2));
//     pRec.p             = v0 + (v1 - v0) * u + (v2 - v0) * v;
//     pRec.n             = getFaceNormal(tri_idx);
//     pRec.uv            = Vector2(u, v);
//     Float area         = getArea();
//     pRec.pdf           = detach(pdf) * detach(1. / getArea(tri_idx));
//     pRec.measure       = EMArea;
//     area               = getArea(tri_idx);
//     pRec.J             = area / detach(area);

//     if (dPSRec) {
//         dPSRec->offset = 0.0;
//         dPSRec->scale = 1.0;
//         dPSRec->v0 = v0;
//         dPSRec->v1 = v1;
//         dPSRec->v2 = v2;
//         dPSRec->emitter_id = light_id;
//         dPSRec->triangle_id = tri_idx;
//         dPSRec->n = pRec.n;
//         dPSRec->continuous = true;
//         dPSRec->uv_mesh0 = Vector(uv_map[vertex_descriptor(indices[tri_idx][0])][0], uv_map[vertex_descriptor(indices[tri_idx][0])][1], 0.0);
//         dPSRec->uv_mesh1 = Vector(uv_map[vertex_descriptor(indices[tri_idx][1])][0], uv_map[vertex_descriptor(indices[tri_idx][1])][1], 0.0);
//         dPSRec->uv_mesh2 = Vector(uv_map[vertex_descriptor(indices[tri_idx][2])][0], uv_map[vertex_descriptor(indices[tri_idx][2])][1], 0.0);
//     }
//     return tri_idx;
// }
// INACTIVE_FN(Shape_samplePositionCont, &Shape::samplePositionCont);

// Fill `its` for a hit of `ray` on triangle `tri_index` of this shape.
// rayIntersectTriangle supplies (u, v, t) = (barycentric u, barycentric v, ray distance).
// `mode` decides which quantities stay attached to the AD graph:
//   ESpatial   - barycentrics (u, v) as computed; its.p = ray.org + ray.dir * t;
//                its.t = t and its.wi from -ray.dir.
//   EMaterial  - barycentrics detached; its.p interpolated from the (attached) vertices.
//   EReference - barycentrics (u, v) as computed; its.p interpolated from detached vertices.
// For every mode other than ESpatial, its.t and its.wi are recomputed from
// its.p - ray.org. If `mode` were none of the three values, its.p would not be
// assigned here. its.J = area / detach(area).
void Shape::rayIntersect(int tri_index, const Ray &ray,
                         Intersection    &its,
                         IntersectionMode mode) const {
    // Presumably isEmitter() inspects state already set on `its` (e.g. its shape); confirm.
    its.type = EVSurface;
    if (its.isEmitter())
        its.type = EVEmitter;
    const Vector3i &ind = getIndices(tri_index);
    const Vector   &v0  = getVertex(ind(0)),
                 &v1    = getVertex(ind(1)),
                 &v2    = getVertex(ind(2));

    const Array uvt = rayIntersectTriangle(v0, v1, v2, ray);
    const Float u = uvt(0), v = uvt(1), t = uvt(2);

    Vector2 barycentric = Vector2(u, v);
    if (mode == EMaterial)
        barycentric = Vector2(detach(u), detach(v));

    Vector geom_normal    = getFaceNormal(tri_index);
    Vector shading_normal = getShadingNormal(tri_index, barycentric);
    // Flip geometric normal to the same side of shading normal
    if (geom_normal.dot(shading_normal) < 0.f) {
        geom_normal = -geom_normal;
    }
    // Frames must be set before toLocal() is used below.
    its.geoFrame = Frame(geom_normal);
    its.shFrame  = Frame(shading_normal);

    if (mode == ESpatial)
        its.p = ray.org + ray.dir * t;
    else if (mode == EMaterial)
        its.p = (1 - barycentric[0] - barycentric[1]) * v0 +
                barycentric[0] * v1 + barycentric[1] * v2;
    else if (mode == EReference)
        its.p = (1 - u - v) * detach(v0) +
                u * detach(v1) +
                v * detach(v2);

    if (mode == ESpatial) {
        its.t           = t;
        its.uv          = getUV(tri_index, barycentric);
        its.wi          = its.toLocal(-ray.dir);
        its.barycentric = barycentric;
    } else {
        // Distance and incoming direction are re-derived from the interpolated
        // point, so they carry its derivatives. Note: dist == 0 would divide by zero.
        Vector dir  = its.p - ray.org;
        Float  dist = dir.norm();
        dir /= dist;
        its.t           = dist;
        its.barycentric = barycentric;
        its.uv          = getUV(tri_index, barycentric);
        its.wi          = its.toLocal(-dir);
    }
    // Numerically 1; presumably carries the area's derivative under AD.
    its.J = getArea(tri_index);
    its.J /= detach(its.J);
}

// Brute-force ray query against every triangle of this shape (O(num_triangles)).
// On success, `its` is filled by the per-triangle rayIntersect overload for
// the closest hit accepted by rayIntersectTriangle (its.ptr_shape cleared,
// its.triangle_id set), and true is returned. Otherwise `its` is left
// untouched and false is returned.
bool Shape::rayIntersect(const Ray &_ray, Intersection &its,
                         IntersectionMode mode) const {
    Ray   ray(_ray);
    Float u, v, t;
    bool  found = false;
    int   f     = -1;
    for (int i = 0; i < num_triangles; ++i) {
        const Vector3i &ind = getIndices(i);
        const Vector   &v0  = getVertex(ind(0)),
                     &v1    = getVertex(ind(1)),
                     &v2    = getVertex(ind(2));
        // Triangles are not ordered by distance, so keep scanning after a hit.
        // ray.tmax shrinks to the best t so far, and later hits must be strictly closer.
        if (rayIntersectTriangle(v0, v1, v2, ray, u, v, t) &&
            (!found || t < ray.tmax)) {
            ray.tmax = t;
            found    = true;
            f        = i;
        }
    }
    if (found) {
        // populate its
        its.ptr_shape   = nullptr; // NOTE: make sure its.isValid()==false;
        its.triangle_id = f;
        rayIntersect(f, ray, its, mode);
    }

    return found;
}

std::tuple<Vector, Vector, Float> Shape::getPoint(int tri_index, const Vector2 &_barycentric) const {
    Vector2 barycentric = detach(_barycentric);
    assert(tri_index >= 0 && tri_index < num_triangles);
    const Vector3i &ind = indices[tri_index];
    const Vector   &v0  = getVertex(ind[0]),
                 &v1    = getVertex(ind[1]),
                 &v2    = getVertex(ind[2]);
    Vector x            = (1. - barycentric.x() - barycentric.y()) * v0 +
               barycentric.x() * v1 +
               barycentric.y() * v2;
    Vector n = getFaceNormal(tri_index);
    Float  J = getArea(tri_index);
    return std::make_tuple(x, n, J / detach(J));
}

void Shape::getPoint(const Intersection &its, Intersection &itsAD) const {
    Vector2 barycentric = detach(its.barycentric);
    int     tri_index   = its.indices[1];
    assert(tri_index >= 0 && tri_index < num_triangles);
    const Vector3i &ind = indices[tri_index];
    const Vector   &v0  = getVertex(ind[0]),
                 &v1    = getVertex(ind[1]),
                 &v2    = getVertex(ind[2]);
    itsAD.p             = (1. - barycentric.x() - barycentric.y()) * v0 +
              barycentric.x() * v1 +
              barycentric.y() * v2;
    Vector geom_normal    = getFaceNormal(tri_index);
    Vector shading_normal = getShadingNormal(tri_index, barycentric);
    itsAD.geoFrame        = Frame(geom_normal);
    itsAD.shFrame         = Frame(shading_normal);
    itsAD.barycentric     = barycentric;
    itsAD.J               = getArea(tri_index);
    itsAD.J /= detach(itsAD.J);
}

void Shape::getPoint(int tri_index, const Vector2 &_barycentric, Intersection &its, Float &J) const {
    Vector2 barycentric = detach(_barycentric);
    Vector  geom_normal;
    std::tie(its.p, geom_normal, J) = getPoint(tri_index, barycentric);
    Vector shading_normal           = getShadingNormal(tri_index, barycentric);
    if (geom_normal.dot(shading_normal) < 0.)
        geom_normal = -geom_normal;

    its.geoFrame = Frame(geom_normal);
    its.shFrame  = Frame(shading_normal);

    its.uv = getUV(tri_index, barycentric);
}

bool Shape::isSihoulette(const Edge &edge, const Vector &p) const {
    if (edge.f0 == -1 || edge.f1 == -1)
        return true;

    const Vector &v0 = vertices[edge.v0],
                 &v1 = vertices[edge.v1];

    bool            frontfacing0 = false;
    const Vector3i &ind0         = indices[edge.f0];
    for (int i = 0; i < 3; i++) {
        if (ind0[i] != edge.v0 && ind0[i] != edge.v1) {
            const Vector &v  = vertices[ind0[i]];
            Vector        n0 = (v0 - v).cross(v1 - v).normalized();
            frontfacing0     = n0.dot(p - v) > 0.0f;
            break;
        }
    }

    bool            frontfacing1 = false;
    const Vector3i &ind1         = indices[edge.f1];
    for (int i = 0; i < 3; i++) {
        if (ind1[i] != edge.v0 && ind1[i] != edge.v1) {
            const Vector &v  = vertices[ind1[i]];
            Vector        n1 = (v1 - v).cross(v0 - v).normalized();
            frontfacing1     = n1.dot(p - v) > 0.0f;
            break;
        }
    }
    return (frontfacing0 && !frontfacing1) ||
           (!frontfacing0 && frontfacing1);
}

// Position of vertex `index` as a function of the scalar motion parameter
// `param`. Depending on which motion flag is set, this is a translation,
// a scale, or a first-order velocity displacement. The base position (and
// velocity) are detached so derivatives w.r.t. `param` flow only through
// the motion term. With no motion flag set, the vertex is returned as-is.
// Rotation is not supported here: an earlier AngleAxis-based version
// segfaulted and was disabled.
Vector Shape::transform(int index, Float param) const {
    const Vector &base = vertices[index];

    if (is_tranlate)
        return param * translation + detach(base);
    if (is_scale)
        return param * scale.cwiseProduct(detach(base));
    if (is_velocity)
        return param * detach(velocities[index]) + detach(base); // 1st-order Taylor expansion
    return base;
}

// Shading normal of triangle `tri_index` at `barycentric` = (u, v).
// With m_use_face_normals set, this is the geometric face normal.
// Otherwise the three vertex normals are interpolated with weights
// (1 - u - v, u, v) and the result is normalized.
Vector Shape::getShadingNormal(int tri_index, const Vector2 &barycentric) const {
    assert(tri_index >= 0 && tri_index < num_triangles);

    // Flat shading: only fetch the face normal when it is actually used.
    if (m_use_face_normals)
        return getFaceNormal(tri_index);

    /* interpolation */
    const Vector3i &ind = indices[tri_index];
    const Float     u   = barycentric[0];
    const Float     v   = barycentric[1];
    const Vector   &n0  = getVertexNormal(ind[0]);
    const Vector   &n1  = getVertexNormal(ind[1]);
    const Vector   &n2  = getVertexNormal(ind[2]);
    return ((1.0f - u - v) * n0 + u * n1 + v * n2).normalized();
}

/* ==================== Normal precomputation ==================== */
#ifdef NORMAL_PREPROCESS
namespace {
// Recompute the normals stored in `shape` from its current geometry:
// face normals first, then (unless the shape uses flat face normals)
// vertex normals. This is the primal function that d_precompute_normal
// below differentiates with Enzyme.
// NOTE(review): presumably computeVertexNormals() consumes the face normals
// computed just before it — confirm before reordering these calls.
void precompute_normal(Shape &shape) {
    /* compute face normal */
    shape.computeFaceNormals();
    /* compute vertex normal */
    if (!shape.m_use_face_normals)
        shape.computeVertexNormals();
}
} // namespace

// Differentiate precompute_normal with Enzyme's __enzyme_autodiff, passing
// `shape` as the primal argument and `d_shape` as its shadow (enzyme_dup).
// NOTE(review): precompute_normal writes into its argument, yet `shape` is
// received here as const; Enzyme executes the primal computation, so this
// presumably overwrites the caller's normals — confirm this is intended.
void d_precompute_normal(const Shape &shape, Shape &d_shape) {
    __enzyme_autodiff((void *) precompute_normal,
                      enzyme_dup, &shape, &d_shape);
}
#endif

/**
 * Write the mesh to `name` in Wavefront OBJ format.
 *
 * Emits one `v` line per vertex, one `vt` line per stored texture
 * coordinate, and one `f` line per triangle. OBJ indices are 1-based, so the
 * 0-based position / uv indices stored in the shape are offset by one.
 * When per-face uv indices are present, faces are written as `v/vt` pairs
 * (position index / texture-coordinate index); otherwise faces reference
 * positions only.
 *
 * @param name  path of the output file (overwritten if it exists)
 * @return true if all data was written, false if the stream failed while
 *         writing or closing
 * @throws std::runtime_error if the file cannot be opened
 */
bool Shape::save(const std::string &name) const {
    std::ofstream ofs(name);
    if (!ofs.is_open())
        throw std::runtime_error("failed to open file");

    // The default 6 significant digits would silently truncate coordinates;
    // 17 digits round-trips any IEEE double (and therefore any float).
    ofs << std::setprecision(17);

    for (int i = 0; i < num_vertices; i++)
        ofs << "v " << vertices[i].x() << " " << vertices[i].y() << " " << vertices[i].z() << '\n';

    // Texture coordinates are indexed separately (via uv_indices), so their
    // count need not equal num_vertices: iterate over the uv array itself.
    const int num_uvs = static_cast<int>(uvs.size());
    for (int i = 0; i < num_uvs; i++)
        ofs << "vt " << uvs[i].x() << " " << uvs[i].y() << '\n';

    if (!uv_indices.empty()) {
        assert(static_cast<int>(uv_indices.size()) == num_triangles);
        for (int i = 0; i < num_triangles; i++) {
            ofs << "f";
            for (int k = 0; k < 3; k++)
                ofs << " " << indices[i][k] + 1 << "/" << uv_indices[i][k] + 1;
            ofs << '\n';
        }
    } else {
        for (int i = 0; i < num_triangles; i++)
            ofs << "f " << indices[i][0] + 1 << " " << indices[i][1] + 1 << " " << indices[i][2] + 1 << '\n';
    }

    ofs.close();
    return !ofs.fail();
}
