#include "sphereAD.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Circle

void CircleAD::init(const VectorAD &C1, const VectorAD &C2, const VectorAD &C3) {
    // Compute an orthonormal basis of the triangle plane
    VectorAD T = (C2 - C1).normalized();
    VectorAD N = T.cross(C3 - C1).normalized();
    VectorAD U = N.cross(T);
    circle_frame.s = T;
    circle_frame.t = U;
    circle_frame.n = N;

    VectorAD circle_c[3];
    circle_c[0] = worldToCircleAD(C1);
    circle_c[1] = worldToCircleAD(C2);
    circle_c[2] = worldToCircleAD(C3);
    FloatAD denom = (circle_c[1].y() - circle_c[2].y()) *
        (circle_c[0].x() - circle_c[2].x()) + (circle_c[2].x() - circle_c[1].x()) *
        (circle_c[0].y() - circle_c[2].y());

    auto isInTri = [&](const VectorAD &p) -> bool {
        FloatAD tmp1 = p.x() - circle_c[2].x();
        FloatAD tmp2 = p.y() - circle_c[2].y();
        FloatAD u = ((circle_c[1].y() - circle_c[2].y()) * tmp1 +
            (circle_c[2].x() - circle_c[1].x()) * tmp2) / denom;
        FloatAD v = ((circle_c[2].y() - circle_c[0].y()) * tmp1 +
            (circle_c[0].x() - circle_c[2].x()) * tmp2) / denom;
        FloatAD w = 1.0 - u - v;
        return epsInclusiveGreater(u.val, 0) && epsInclusiveLesser(u.val, 1.0) &&
            epsInclusiveGreater(v.val, 0) && epsInclusiveLesser(v.val, 1.0) &&
            epsInclusiveGreater(w.val, 0) && epsInclusiveLesser(w.val, 1.0);
    };

    std::vector<FloatAD> phis;

    // Find intersections with each side of the triangle
    for (size_t i = 0; i < 3; i++) {
        size_t j = i == 2 ? 0 : i + 1;

        FloatAD x1 = circle_c[i].x();
        FloatAD y1 = circle_c[i].y();
        FloatAD x2 = circle_c[j].x();
        FloatAD y2 = circle_c[j].y();

        FloatAD a = (x1 - x2).pow(2) + (y1 - y2).pow(2);
        FloatAD b = 2.0 * (x1 * x2 - x2 * x2 + y1 * y2 - y2 * y2);
        FloatAD c = x2 * x2 + y2 * y2 - 1.0;
        FloatAD disc = b * b - 4.0 * a * c;

        // No intersection
        if (epsExclusiveLesser(disc, 0.0)) continue;

        // One intersection
        if (almostEqual(disc, 0.0)) {
            FloatAD alpha = -0.5 * b / a;
            if (!almostEqual(alpha, 0.0) && epsExclusiveGreater(alpha, 0.0)
                && epsInclusiveLesser(alpha, 1.0)) {
                VectorAD itst = alpha * circle_c[i] + (1.0 - alpha) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
        }
        // Two intersection
        else {
            disc = disc.sqrt();
            FloatAD alpha1 = (-b + disc) * 0.5 / a;
            FloatAD alpha2 = (-b - disc) * 0.5 / a;
            if (!almostEqual(alpha1, 0.0) && epsExclusiveGreater(alpha1, 0.0)
                && epsInclusiveLesser(alpha1, 1.0)) {
                VectorAD itst = alpha1 * circle_c[i] + (1.0 - alpha1) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
            if (!almostEqual(alpha2, 0.0) && epsExclusiveGreater(alpha2, 0.0)
                && epsInclusiveLesser(alpha2, 1.0)) {
                VectorAD itst = alpha2 * circle_c[i] + (1.0 - alpha2) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
        }
    }

    if (phis.empty()) {
        // The circle is completely inside the triangle
        if (isInTri(VectorAD(0, 0, 0)) && isInTri(VectorAD(1, 0, 0))
            && isInTri(VectorAD(0, 1, 0))) {
            phis.push_back(0.0);
        }
        else {
            status = 1;
            return;
        }
    }

    // Map phi into [0, 2pi]
    for (FloatAD &phi : phis) {
        // TODO: convert this into its AD version
        phi.val = fmod(phi.val, 2.0 * M_PI);
        if (epsExclusiveLesser(phi, 0.0)) phi.val += 2.0 * M_PI;
    }

    // Sort all the angles
    std::sort(phis.begin(), phis.end());

    for (size_t i = 0; i < phis.size(); i++) {
        size_t j;
        FloatAD start = phis[i];
        FloatAD end;
        if (i == phis.size() - 1) {
            j = 0;
            end = phis[j] + 2.0 * M_PI;
        }
        else {
            j = i + 1;
            end = phis[j];
        }

        VectorAD mid = angleToPointAD((start + end) * 0.5); // midpoint of the curve 

        // The curve is inside the triangle
        if (isInTri(mid)) {
            curves.push_back(Vector2AD(start, end));
        }
    }

    if (curves.empty()) {
        status = 2;
    }
    else {
        status = 3;
    }
}

/**
 * Uniformly samples an angle over the union of the arcs in `curves` and returns
 * the corresponding point in world space.
 *
 * @param rn        random number, expected in [0, 1].
 * @param phi       [out] sampled angle (radians), wrapped with fmod into [0, 2pi).
 * @param pdf       [out] 1 / (total angular length of all arcs), i.e. the density
 *                  w.r.t. phi.
 * @param jacobian  [out] set to the member `tau`.
 * @return the sampled point mapped through angleToPointAD and circleToWorldAD.
 *
 * NOTE(review): with an empty `curves` the total length is 0 and pdf becomes
 * 1/0; callers presumably only sample after init() reported status 3 -- confirm.
 */
VectorAD CircleAD::sampleCircularCurve(const Float &rn, Float &phi, Float &pdf, FloatAD &jacobian) {
    // cumsum[k] = total angular length of curves[0 .. k-1].
    std::vector<Float> cumsum;
    cumsum.reserve(curves.size() + 1);
    cumsum.push_back(0.0);
    for (const Vector2AD &curve : curves) {
        cumsum.push_back(cumsum.back() + curve.y().val - curve.x().val);
    }
    const Float total = cumsum.back();

    // Uniform density over the total arc length.
    pdf = 1.0 / total;

    // Walk the arcs until the scaled random number falls inside one of them.
    Float scaled_rn = rn * total;
    phi = 0;
    bool found = false;
    for (size_t i = 0; i < curves.size(); i++) {
        Float len_interval = cumsum[i + 1] - cumsum[i];
        if (scaled_rn <= len_interval) {
            phi = curves[i].x().val + scaled_rn;
            found = true;
            break;
        }
        scaled_rn -= len_interval;
    }
    // When rn is close to 1, round-off in the subtraction above can leave scaled_rn
    // marginally larger than the last interval. Clamp to the end of the last arc
    // instead of keeping phi = 0, which may lie outside every arc.
    if (!found && !curves.empty()) {
        phi = curves.back().y().val;
    }
    phi = fmod(phi, 2.0 * M_PI);

    // Compute the jacobian.
    jacobian = tau;

    // Map the angle to a point in the circle frame, then to world space.
    VectorAD p_circle = angleToPointAD(phi);
    VectorAD p_world = circleToWorldAD(p_circle);
    return p_world;
}

/// Polar angle of p's x/y components in circle coordinates, computed as atan2(y, x).
FloatAD CircleAD::pointToAngleAD(const VectorAD &p) const {
    return FloatAD(p.y().atan2(p.x()));
}

/// Point (cos(angle), sin(angle), 0) on the unit circle in circle coordinates.
VectorAD CircleAD::angleToPointAD(const FloatAD &angle) const {
    return VectorAD(angle.cos(), angle.sin(), 0.0);
}

/// Maps a world-space point into circle coordinates.
/// The point is translated by -orig, expressed in circle_frame, then divided by `radius`.
VectorAD CircleAD::worldToCircleAD(const VectorAD &p_world) const {
    return circle_frame.toLocal(p_world - orig) / radius;
}

/// Inverse of worldToCircleAD.
/// The point is scaled by `radius`, expressed in world axes via circle_frame, then offset by orig.
VectorAD CircleAD::circleToWorldAD(const VectorAD &p_circle) const {
    return circle_frame.toWorld(p_circle * radius) + orig;
}

// Sphere

void SphereAD::init() {
    Vector delta(radius.val, radius.val, radius.val);
    bounds = AABB(orig.val - delta, orig.val + delta);
}

/**
 * Intersects the sphere with the plane of triangle (C1, C2, C3) and returns the
 * resulting cross-section circle, clipped to the triangle by CircleAD's constructor.
 *
 * Returns a default-constructed CircleAD when the plane misses the sphere: |distance|
 * exceeds the radius by more than epsilon, or the cross-section radius is not positive
 * (tangency within epsilon).
 */
CircleAD SphereAD::intersectTri(const VectorAD &C1, const VectorAD &C2, const VectorAD &C3) {
    // Unit normal of the triangle plane.
    VectorAD sphere_n = (C2 - C1).cross(C3 - C1).normalized();

    // Signed distance from the sphere center to the plane along sphere_n. It is
    // negative when the plane lies on the -n side of the center.
    FloatAD sphere_dist = sphere_n.dot(C1 - orig);

    // The distance is signed, so compare its magnitude. Testing only dist > radius let
    // planes far on the opposite side through and produced sqrt of a negative value below.
    if (epsExclusiveGreater(std::fabs(sphere_dist.val), radius.val)) return CircleAD();

    // Squared cross-section radius (Pythagoras). Inside the epsilon band of tangency it can
    // be zero or slightly negative; treat that as no intersection instead of producing a
    // NaN or zero-radius circle.
    FloatAD circle_r2 = radius * radius - sphere_dist * sphere_dist;
    if (!(circle_r2.val > 0.0)) return CircleAD();

    // Projection of the sphere center onto the plane, and the cross-section radius.
    VectorAD circle_O = orig + sphere_dist * sphere_n;
    FloatAD circle_r = circle_r2.sqrt();

    CircleAD circle(C1, C2, C3, circle_O, circle_r, radius);
    return circle;
}
