#include "epc_sampler.h"

// Selects how EPCSampler::init() gathers candidate triangles:
//   non-zero -> query scene.ptr_ECBVHManager->intersectEllipsoid() and filter its result;
//   0        -> loop over every triangle of every shape in scene.shape_list.
#define USE_BVH_ACCEL_FOR_EPC 1

// Minimum allowed distance between a sampled point x and either endpoint (x1 or x2).
// EPCSampler::sample() rejects samples closer than this to avoid a too-large G term.
#define DIST_THRESHOLD 1.0

/// Builds a sampler for the ellipsoid described by (x1_, n1_, x2_, n2_, tau) inside scene_.
///
/// The five geometric arguments are forwarded unchanged to the `ellipsoid` member's
/// constructor; `scene_` initializes the `scene` member. The candidate-triangle list is
/// computed eagerly by init(), so the object is ready for sample() once construction ends.
///
/// NOTE(review): x1_/x2_ presumably become the ellipsoid foci (f1/f2) and n1_/n2_ the
/// normals at those points; the exact meaning of tau is defined by the ellipsoid type —
/// confirm in its declaration.
EPCSampler::EPCSampler(const Vector& x1_,
                       const Vector& n1_,
                       const Vector& x2_,
                       const Vector& n2_,
                       Float tau,
                       const Scene& scene_)
    : ellipsoid(x1_, n1_, x2_, n2_, tau),
      scene(scene_)
{
    // Precompute tri_list up front so sample() only has to pick from it.
    init();
}

void EPCSampler::init() {
    // Find out all intersected ellipses.
#if USE_BVH_ACCEL_FOR_EPC
    std::vector<ECBVHManager::TriWithID> its_tri;
    scene.ptr_ECBVHManager->intersectEllipsoid(ellipsoid, its_tri);

    for (const auto& tri : its_tri) {
        auto shape = scene.shape_list[tri.shape_id];
        Vector3i ind = shape->getIndices(tri.tri_id);
        Vector v[3] = { shape->getVertex(ind(0)), shape->getVertex(ind(1)), shape->getVertex(ind(2)) };
        Vector faceN = (v[1] - v[0]).cross(v[2] - v[0]).normalized();

        // Early reject some triangles.
        if (ECBVHManager::triIsInNegativeHalfspace(v, ellipsoid.f1, ellipsoid.n1) ||
            ECBVHManager::triIsInNegativeHalfspace(v, ellipsoid.f2, ellipsoid.n2) ||
            ECBVHManager::focusIsInNegativeHalfspace(ellipsoid, v[0], faceN) ||
            ECBVHManager::triIsInEllipsoid(v, ellipsoid))
            continue;

        tri_list.push_back(tri);
    }
    //printf("%d / %d\n", tri_list.size(), its_tri.size());
#else
    for (size_t shape_id = 0; shape_id < scene.shape_list.size(); shape_id++) {
        int num_tri_shape = scene.shape_list[shape_id]->num_triangles;
        tri_list.reserve(tri_list.size() + num_tri_shape);
        for (int tri_id = 0; tri_id < num_tri_shape; tri_id++) {
            auto shape = scene.shape_list[shape_id];
            Vector3i ind = shape->getIndices(tri_id);
            Vector v[3] = { shape->getVertex(ind(0)), shape->getVertex(ind(1)), shape->getVertex(ind(2)) };

            // Early reject some triangles.
            Vector faceN = (v[1] - v[0]).cross(v[2] - v[0]).normalized();
            if (ECBVHManager::triIsInNegativeHalfspace(v, ellipsoid.f1, ellipsoid.n1) ||
                ECBVHManager::triIsInNegativeHalfspace(v, ellipsoid.f2, ellipsoid.n2) ||
                ECBVHManager::focusIsInNegativeHalfspace(ellipsoid, v[0], faceN) ||
                ECBVHManager::triIsInEllipsoid(v, ellipsoid))
                continue;

            tri_list.push_back(ECBVHManager::TriWithID{ (int)shape_id, tri_id });
        }
    }
#endif
}

bool EPCSampler::sample(const Vector2& rnd2, Intersection& x_its, Float& jacobian, Float& pdf) {
    if (tri_list.empty()) return false;

    const Vector& x1 = ellipsoid.f1;
    const Vector& x2 = ellipsoid.f2;
    const Vector& n1 = ellipsoid.n1;
    const Vector& n2 = ellipsoid.n2;

    size_t index = static_cast<size_t>(floor(rnd2[0] * tri_list.size()));
    assert(0 <= index && index < tri_list.size());

    ECBVHManager::TriWithID tri = tri_list[index];
    auto shape = scene.shape_list[tri.shape_id];
    Vector3i ind = shape->getIndices(tri.tri_id);
    Vector v[3] = { shape->getVertex(ind(0)), shape->getVertex(ind(1)), shape->getVertex(ind(2)) };
    Ellipse ellipse = ellipsoid.intersectTri(v[0], v[1], v[2]);
    if (!ellipse.canSample())
        return false;

    // Sample a point from the elliptic curves.
    Float phi, x_pdf;
    Vector x_ellipsoid = ellipse.sampleEllipticCurve(rnd2[1], phi, x_pdf, jacobian);
    Vector x = ellipsoid.ellipsoidToWorld(x_ellipsoid);
    pdf = x_pdf / tri_list.size();

    Vector x1_to_x = (x - x1).normalized();
    Ray ray_x1_to_x(x1, x1_to_x);
    scene.rayIntersect(ray_x1_to_x, true, x_its);
    if (!x_its.isValid()) 
        return false;

    // x_its and x should be the same.
    if (std::abs((x_its.p - x).norm()) > 1e-4) {
        return false;
    }

    Vector x_to_x2 = (x2 - x).normalized();
    // Check if (x and x1) or (x and x2) are in the same plane.
    if (std::abs(n1.dot(x1_to_x)) < Epsilon || std::abs(n2.dot(-x_to_x2)) < Epsilon)
        return false;

    // Check if there are any occluders between x and x2.
    if (!scene.isVisible(x, true, x2, true))
        return false;

    // Reject the case that x is too close to x1 or x2, therefore avoid too large G.
    if ((x - x1).norm() < DIST_THRESHOLD || (x - x2).norm() < DIST_THRESHOLD)
        return false;

    return true;
}
