#include "sphere.h"
#include <iostream>

// Circle

void Circle::init(const Vector &C1, const Vector &C2, const Vector &C3) {
    // Compute an orthonormal basis of the triangle plane
    Vector T = (C2 - C1).normalized();
    Vector N = T.cross(C3 - C1).normalized();
    Vector U = N.cross(T);
    circle_frame.s = T;
    circle_frame.t = U;
    circle_frame.n = N;

    Vector circle_c[3];
    circle_c[0] = worldToCircle(C1);
    circle_c[1] = worldToCircle(C2);
    circle_c[2] = worldToCircle(C3);
    Float denom = (circle_c[1].y() - circle_c[2].y()) *
        (circle_c[0].x() - circle_c[2].x()) + (circle_c[2].x() - circle_c[1].x()) *
        (circle_c[0].y() - circle_c[2].y());

    auto isInTri = [&](const Vector &p) -> bool {
        Float tmp1 = p.x() - circle_c[2].x();
        Float tmp2 = p.y() - circle_c[2].y();
        Float u = ((circle_c[1].y() - circle_c[2].y()) * tmp1 +
            (circle_c[2].x() - circle_c[1].x()) * tmp2) / denom;
        Float v = ((circle_c[2].y() - circle_c[0].y()) * tmp1 +
            (circle_c[0].x() - circle_c[2].x()) * tmp2) / denom;
        Float w = 1.0 - u - v;
        return epsInclusiveGreater(u, 0) && epsInclusiveLesser(u, 1.0) &&
            epsInclusiveGreater(v, 0) && epsInclusiveLesser(v, 1.0) &&
            epsInclusiveGreater(w, 0) && epsInclusiveLesser(w, 1.0);
    };

    std::vector<Float> phis;

    // Find intersections with each side of the triangle
    for (size_t i = 0; i < 3; i++) {
        size_t j = i == 2 ? 0 : i + 1;

        Float x1 = circle_c[i].x();
        Float y1 = circle_c[i].y();
        Float x2 = circle_c[j].x();
        Float y2 = circle_c[j].y();

        Float a = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2);
        Float b = 2.0 * (x1 * x2 - x2 * x2 + y1 * y2 - y2 * y2);
        Float c = x2 * x2 + y2 * y2 - 1.0;
        Float disc = b * b - 4.0 * a * c;

        // No intersection
        if (epsExclusiveLesser(disc, 0.0)) continue;

        // One intersection
        if (almostEqual(disc, 0.0)) {
            Float alpha = -0.5 * b / a;
            if (!almostEqual(alpha, 0.0) && epsExclusiveGreater(alpha, 0.0)
                && epsInclusiveLesser(alpha, 1.0)) {
                Vector itst = alpha * circle_c[i] + (1.0 - alpha) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
        }
        // Two intersection
        else {
            disc = std::sqrt(disc);
            Float alpha1 = (-b + disc) * 0.5 / a;
            Float alpha2 = (-b - disc) * 0.5 / a;
            if (!almostEqual(alpha1, 0.0) && epsExclusiveGreater(alpha1, 0.0)
                && epsInclusiveLesser(alpha1, 1.0)) {
                Vector itst = alpha1 * circle_c[i] + (1.0 - alpha1) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
            if (!almostEqual(alpha2, 0.0) && epsExclusiveGreater(alpha2, 0.0)
                && epsInclusiveLesser(alpha2, 1.0)) {
                Vector itst = alpha2 * circle_c[i] + (1.0 - alpha2) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
        }
    }

    if (phis.empty()) {
        // The circle is completely inside the triangle
        if (isInTri(Vector(0, 0, 0)) && isInTri(Vector(1, 0, 0))
            && isInTri(Vector(0, 1, 0))) {
            phis.push_back(0.0);
        }
        else {
            status = 1;
            return;
        }
    }

    // Map phi into [0, 2pi]
    for (Float &phi : phis) {
        // TODO: convert this into its AD version
        phi = fmod(phi, 2.0 * M_PI);
        if (epsExclusiveLesser(phi, 0.0)) phi += 2.0 * M_PI;
    }

    // Sort all the angles
    std::sort(phis.begin(), phis.end());

    for (size_t i = 0; i < phis.size(); i++) {
        size_t j;
        Float start = phis[i];
        Float end;
        if (i == phis.size() - 1) {
            j = 0;
            end = phis[j] + 2.0 * M_PI;
        }
        else {
            j = i + 1;
            end = phis[j];
        }

        Vector mid = angleToPoint((start + end) * 0.5); // midpoint of the curve 

        // The curve is inside the triangle
        if (isInTri(mid)) {
            curves.push_back(Vector2(start, end));
        }
    }

    if (curves.empty()) {
        status = 2;
    }
    else {
        status = 3;
    }
}

// Samples a point uniformly (in angle) over the union of arcs in `curves`.
//
// rn       - random number, assumed to lie in [0, 1].
// phi      - out: sampled angle, reduced modulo 2pi.
// pdf      - out: density of the sample w.r.t. angle, 1 / (total arc length).
// jacobian - out: set to the member `tau`.
// Returns the sampled point in world space.
//
// NOTE(review): with an empty `curves` the total length is 0 and pdf becomes
// 1/0; callers presumably only sample after init() reported status 3 — confirm.
Vector Circle::sampleCircularCurve(const Float &rn, Float &phi, Float &pdf, Float &jacobian) {
    // Total angular length of all arcs.
    Float total_len = 0.0;
    for (const Vector2 &curve : curves) {
        total_len += curve.y() - curve.x();
    }

    // Uniform density per unit angle.
    pdf = 1.0 / total_len;

    // Walk the arcs until the scaled random number falls inside one of them.
    Float remaining = rn * total_len;

    // Rounding in the running subtraction can leave `remaining` a hair larger
    // than the last arc (e.g. rn ~ 1); default to the end of the last arc so
    // the sample never lands outside the valid arcs.
    phi = 0;
    if (!curves.empty()) {
        phi = curves.back().y();
    }

    for (const Vector2 &curve : curves) {
        const Float len = curve.y() - curve.x();
        if (remaining <= len) {
            phi = curve.x() + remaining;
            break;
        }
        remaining -= len;
    }

    // The wrapping arc may end beyond 2pi; bring the angle back into range.
    phi = fmod(phi, 2.0 * M_PI);

    // NOTE(review): `tau` is a member declared outside this file chunk.
    jacobian = tau;

    // Map the angle to a point on the unit circle, then into world space.
    Vector p_circle = angleToPoint(phi);
    Vector p_world = circleToWorld(p_circle);
    return p_world;
}

// Polar angle of p measured in the circle plane (x/y components only),
// as returned by atan2, i.e. in [-pi, pi].
Float Circle::pointToAngle(const Vector &p) const {
    return Float(std::atan2(p.y(), p.x()));
}

// Point on the unit circle (z = 0) at the given polar angle, in circle
// coordinates. Inverse of pointToAngle for points on the unit circle.
Vector Circle::angleToPoint(const Float &angle) const {
    return Vector(std::cos(angle), std::sin(angle), 0.0);
}

// World -> circle coordinates: translate by the circle center `orig`, rotate
// into `circle_frame`, then divide by `radius` so the circle has unit radius.
Vector Circle::worldToCircle(const Vector &p_world) const {
    return circle_frame.toLocal(p_world - orig) / radius;
}

// Circle -> world coordinates (inverse of worldToCircle): undo the unit-radius
// scaling, rotate back out of `circle_frame`, and translate by `orig`.
Vector Circle::circleToWorld(const Vector &p_circle) const {
    return circle_frame.toWorld(p_circle * radius) + orig;
}

// Sphere

void Sphere::init() {
    Vector delta(radius, radius, radius);
    bounds = AABB(orig - delta, orig + delta);
}

// Intersects this sphere with the plane of triangle (C1, C2, C3).
//
// Returns a default-constructed Circle when the plane is farther than `radius`
// (beyond the epsilon tolerance) from the sphere center on either side.
// Otherwise returns the intersection circle, whose center is the projection of
// the sphere center onto the plane, built from the triangle vertices.
Circle Sphere::intersectTri(const Vector &C1, const Vector &C2, const Vector &C3) {
    // Unit normal of the triangle plane.
    Vector sphere_n = (C2 - C1).cross(C3 - C1).normalized();

    // SIGNED distance from the sphere center to the plane along sphere_n; it
    // is negative when the center lies on the side sphere_n points toward.
    Float sphere_dist = sphere_n.dot(C1 - orig);

    // The plane misses the sphere if it is too far on EITHER side; checking
    // only the positive side let sqrt() below receive a negative argument.
    if (epsExclusiveGreater(sphere_dist, radius) ||
        epsExclusiveLesser(sphere_dist, -radius)) {
        return Circle();
    }

    // Projection of the sphere center onto the plane (valid for either sign).
    Vector circle_O = orig + sphere_dist * sphere_n;

    // Near tangency |sphere_dist| may exceed radius by up to epsilon; clamp the
    // radicand so the circle radius is 0 instead of NaN.
    // NOTE(review): a ~0 circle radius still reaches Circle::worldToCircle,
    // which divides by its radius — confirm the Circle constructor handles it.
    Float circle_r2 = radius * radius - sphere_dist * sphere_dist;
    if (circle_r2 < 0.0) {
        circle_r2 = 0.0;
    }
    Float circle_r = std::sqrt(circle_r2);

    Circle circle(C1, C2, C3, circle_O, circle_r, radius);
    return circle;
}
