#include "sphere_sampler_AD.h"

// Minimum allowed distance between a sampled point x and the sphere origin x1.
// SphereSamplerAD::sample() rejects samples closer than this to avoid a very
// large geometry term.
#define DIST_THRESHOLD 1e-2

// Builds the sampling sphere from (x_, n_, tau) and keeps a reference to the
// scene. tau is forwarded as the third sphere constructor argument; init()
// reads it back as sphere.radius.
// NOTE(review): `scene` is stored by reference, so the Scene must outlive this
// sampler — confirm against the member declaration.
SphereSamplerAD::SphereSamplerAD(const VectorAD& x_, const VectorAD& n_, const FloatAD& tau, const Scene& scene_)
    : sphere(x_, n_, tau)
    , scene(scene_)
{
    // Pre-compute the candidate triangle list that sample() draws from.
    this->init();
}

void SphereSamplerAD::init() {
    std::vector<ECBVHManager::TriWithID> its_tri;
    Sphere sph(sphere.orig.val, sphere.n.val, sphere.radius.val);
    scene.ptr_sphere_BVHManager->intersectSphere(sph, its_tri);

    for (const auto& tri : its_tri) {
        auto shape = scene.shape_list[tri.shape_id];
        Vector3i ind = shape->getIndices(tri.tri_id);
        Vector v[3] = { shape->getVertex(ind(0)), shape->getVertex(ind(1)), shape->getVertex(ind(2)) };
        Vector faceN = (v[1] - v[0]).cross(v[2] - v[0]).normalized();

        // Early reject some triangles.
        if (ECBVHManager::triIsInNegativeHalfspace(v, sph.orig, sph.n) ||
            (sph.orig - v[0]).dot(faceN) < Epsilon ||    // Sphere origin is in the negative halfspace.
            ECBVHManager::triIsInSphere(v, sph))
            continue;

        tri_list.push_back(tri);
    }
}

bool SphereSamplerAD::sample(const Vector2& rnd2, IntersectionAD& x_its, FloatAD& jacobian, Float& pdf) {
    if (tri_list.empty()) return false;

    const VectorAD& x1 = sphere.orig;
    const VectorAD& n1 = sphere.n;

    size_t index = static_cast<size_t>(floor(rnd2[0] * tri_list.size()));
    assert(0 <= index && index < tri_list.size());

    ECBVHManager::TriWithID tri = tri_list[index];
    auto shape = scene.shape_list[tri.shape_id];
    Vector3i ind = shape->getIndices(tri.tri_id);
    VectorAD v[3] = { shape->getVertexAD(ind(0)), shape->getVertexAD(ind(1)), shape->getVertexAD(ind(2)) };
    CircleAD circle = sphere.intersectTri(v[0], v[1], v[2]);
    if (!circle.canSample())
        return false;

    // Sample a point from the circular curves.
    Float phi, x_pdf;
    VectorAD x = circle.sampleCircularCurve(rnd2[1], phi, x_pdf, jacobian);
    pdf = x_pdf / tri_list.size();

    Vector x1_to_x = (x.val - x1.val).normalized();
    RayAD ray_x1_to_x(x1.val, x1_to_x);
    scene.rayIntersectAD(ray_x1_to_x, true, x_its);
    if (!x_its.isValid())
        return false;

    // x_its and x should be the same.
    if (std::abs((x_its.p.val - x.val).norm()) > 1e-4) {
        return false;
    }

    x_its.p = x;

    // Check if (x and x1) are in the same plane.
    if (std::abs(n1.val.dot(x1_to_x)) < Epsilon)
        return false;

    // Reject the case that x is too close to x1, therefore avoid too large G.
    if ((x.val - x1.val).norm() < DIST_THRESHOLD)
        return false;

    return true;
}