#include "ellipsoid.h"
#include <iostream>

// Ellipse
void Ellipse::init(const Vector &C1, const Vector &C2, const Vector &C3, const Vector &abc, const Float &tau) {  
    // Compute an orthonormal basis of the triangle plane
    Vector T = (C2 - C1).normalized();
    N = T.cross(C3 - C1).normalized();
    Vector U = N.cross(T);

    // TODO: Consider extending simpleAD.h to convert Array to Vector
    Vector D(1.0 / (abc[0] * abc[0]), 1.0 / (abc[1] * abc[1]), 1.0 / (abc[2] * abc[2]));

    Float TTD = quadDot(T, T, D);
    Float TUD = quadDot(T, U, D);
    Float UUD = quadDot(U, U, D);
    Float OOD = quadDot(O, O, D);

    // Compute theta according to SM 1.1. 
    theta = 0.5 * std::atan2(2.0 * TUD, UUD - TTD);

    // Compute some useful intermediate values. 
    Float Delta = almostEqual(TTD, UUD) ? 
        2.0 * std::abs(TUD) :
        std::sqrt(4.0 * std::pow(TUD, 2) + std::pow(TTD - UUD, 2));
    Float DR1 = TTD + UUD - Delta;
    Float DR2 = TTD + UUD + Delta;

    // Compute m1 and m2 according to SM 1.2.
    m1 = std::sqrt(2.0 * (1.0 - OOD) / DR1);
    m2 = std::sqrt(2.0 * (1.0 - OOD) / DR2);

    Float cosTheta = std::cos(theta);
    Float sinTheta = std::sin(theta);
    TN = T * cosTheta - U * sinTheta;
    UN = T * sinTheta + U * cosTheta;

    ellipse_frame.s = TN;
    ellipse_frame.t = UN;
    ellipse_frame.n = N;

    Vector circle_c[3];
    circle_c[0] = ellipsoidToCircle(C1);
    circle_c[1] = ellipsoidToCircle(C2);
    circle_c[2] = ellipsoidToCircle(C3);
    Float denom = (circle_c[1].y() - circle_c[2].y()) * 
        (circle_c[0].x() - circle_c[2].x()) + (circle_c[2].x() - circle_c[1].x()) * 
        (circle_c[0].y() - circle_c[2].y());

    auto isInTri = [&](const Vector &p) -> bool {
        Float tmp1 = p.x() - circle_c[2].x();
        Float tmp2 = p.y() - circle_c[2].y();
        Float u = ((circle_c[1].y() - circle_c[2].y()) * tmp1 + 
            (circle_c[2].x() - circle_c[1].x()) * tmp2) / denom;
        Float v = ((circle_c[2].y() - circle_c[0].y()) * tmp1 + 
            (circle_c[0].x() - circle_c[2].x()) * tmp2) / denom;
        Float w = 1.0 - u - v;
        return epsInclusiveGreater(u, 0) && epsInclusiveLesser(u, 1.0) && 
            epsInclusiveGreater(v, 0) && epsInclusiveLesser(v, 1.0) &&
            epsInclusiveGreater(w, 0) && epsInclusiveLesser(w, 1.0);
    };

    std::vector<Float> phis;

    // Find intersections with each side of the triangle
    for (size_t i = 0; i < 3; i++) {
        size_t j = i == 2 ? 0 : i + 1;

        Float x1 = circle_c[i].x();
        Float y1 = circle_c[i].y();
        Float x2 = circle_c[j].x();
        Float y2 = circle_c[j].y();

        Float a = std::pow(x1 - x2, 2) + std::pow(y1 - y2, 2);
        Float b = 2.0 * (x1 * x2 - x2 * x2 + y1 * y2 - y2 * y2);
        Float c = x2 * x2 + y2 * y2 - 1.0;
        Float disc = b * b - 4.0 * a * c;

        // No intersection
        if (epsExclusiveLesser(disc, 0.0)) continue;
        
        // One intersection
        if (almostEqual(disc, 0.0)) {
            Float alpha = -0.5 * b / a;
            if (!almostEqual(alpha, 0.0) && epsExclusiveGreater(alpha, 0.0) 
                && epsInclusiveLesser(alpha, 1.0)) {
                Vector itst = alpha * circle_c[i] + (1.0 - alpha) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
        }
        // Two intersection
        else {
            disc = std::sqrt(disc);
            Float alpha1 = (-b + disc) * 0.5 / a;
            Float alpha2 = (-b - disc) * 0.5 / a;
            if (!almostEqual(alpha1, 0.0) && epsExclusiveGreater(alpha1, 0.0) 
                && epsInclusiveLesser(alpha1, 1.0)) {
                Vector itst = alpha1 * circle_c[i] + (1.0 - alpha1) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
            if (!almostEqual(alpha2, 0.0) && epsExclusiveGreater(alpha2, 0.0) 
                && epsInclusiveLesser(alpha2, 1.0)) {
                Vector itst = alpha2 * circle_c[i] + (1.0 - alpha2) * circle_c[j];
                phis.push_back(pointToAngle(itst));
            }
        }
    }

    if (phis.empty()) {
        // The circle is completely inside the triangle
        if (isInTri(Vector(0, 0, 0)) && isInTri(Vector(1, 0, 0)) 
            && isInTri(Vector(0, 1, 0))) {
            phis.push_back(0.0);
        }
        else {
            status = 1;
            return;
        }
    }

    // Map phi into [0, 2pi]
    for (Float &phi : phis) {
        phi = fmod(phi, 2.0 * M_PI);
        if (epsExclusiveLesser(phi, 0.0)) phi += 2.0 * M_PI;
    }

    // Sort all the angles
    std::sort(phis.begin(), phis.end());

    for (size_t i = 0; i < phis.size(); i++) {
        size_t j;
        Float start = phis[i];
        Float end;
        if (i == phis.size() - 1) {
            j = 0;
            end = phis[j] + 2.0 * M_PI;
        }
        else {
            j = i + 1;
            end = phis[j];
        }

        Vector mid = angleToPoint((start + end) * 0.5); // midpoint of the curve 

        // The curve is inside the triangle
        if (isInTri(mid)) {
            curves.push_back(Vector2(start, end));
        }
    }

    if (curves.empty()) {
        status = 2;
    }
    else {
        status = 3;

        // Compute useful values for Jacobian computation.
        Vector tmp(1.0 / std::pow(abc[0], 4), 1.0 / std::pow(abc[1], 4), 1.0 / std::pow(abc[2], 4));
        Vector E = -abc.x() * tmp;
        
        Float TTE = quadDot(T, T, E);
        Float TUE = quadDot(T, U, E);
        Float UUE = quadDot(U, U, E);
        Float OOE = quadDot(O, O, E);

        // Compute dTN and dUN according to SM 3.1.
        Float dDelta;
        if (epsExclusiveGreater(Delta, 0.0)) {
            dDelta = (4.0 * TUD * TUE + (TTE - UUE) * (TTD - UUD)) / Delta;
            Float Delta2 = Delta * Delta;

            Float nsnthetadtheta = (Delta * (UUE - TTE) - dDelta * (UUD - TTD)) / Delta2;
            Float cnthetadtheta = 2.0 * (Delta * TUE - dDelta * TUD) / Delta2;

            dTN = T * nsnthetadtheta - U * cnthetadtheta;
            dUN = T * cnthetadtheta + U * nsnthetadtheta;
        }
        else {
            dDelta = 0;
            dTN = Vector::Zero();
            dUN = Vector::Zero();
        }

        // Compute dO according to SM 3.3.
        Array3 Tabc(T.array() / abc.array().pow(2));
        Array3 Uabc(U.array() / abc.array().pow(2));

        Matrix3x3 dOM;
        dOM.row(0) << Tabc(0), Tabc(1), Tabc(2);
        dOM.row(1) << Uabc(0), Uabc(1), Uabc(2);
        dOM.row(2) << N(0), N(1), N(2);
        Matrix3x3 dOMinv = Matrix3x3(dOM).inverse();

        Vector dOV(quadDot(T, O, tmp), quadDot(U, O, tmp), 0.0);
        dO = 0.5 * tau * dOMinv * dOV;

        // Compute dm1 and dm2 according to SM 3.2.
        Float NR = 2.0 * (1.0 - OOD);
        Float OdOD = quadDot(O, dO, D);
        Float dNR = -2.0 * OOE - 4.0 * OdOD;
        Float dDR1 = TTE + UUE - dDelta;
        Float dDR2 = TTE + UUE + dDelta;
        dm1 = (DR1 * dNR - dDR1 * NR) / (2.0 * std::pow(DR1, 2) * m1);
        dm2 = (DR2 * dNR - dDR2 * NR) / (2.0 * std::pow(DR2, 2) * m2);
    }
}

// Samples an angle phi uniformly (with respect to angle) over the arcs stored
// in `curves`, driven by rn (presumably a uniform random number in [0, 1]).
//
// Outputs:
//   phi      - the sampled angle, reduced with fmod(., 2pi);
//   pdf      - 1 / (total angular length of all arcs);
//   jacobian - |m2 dm1 cos^2 + m1 dm2 sin^2 + dO.(m2 cos TN + m1 sin UN)
//               + O.(m2 cos dTN + m1 sin dUN)|, evaluated at phi.
// Returns the sampled boundary point in ellipsoid coordinates.
// Expects init() to have produced at least one arc (status == 3); with no
// arcs the pdf is 1 / 0.
Vector Ellipse::sampleEllipticCurve(const Float &rn, Float &phi, Float &pdf, Float &jacobian) {
    // Compute the total angular length of all arcs.
    Float total = 0.0;
    for (const Vector2 &curve : curves) {
        total += curve.y() - curve.x();
    }

    // Compute the pdf for uniform sampling.
    pdf = 1.0 / total;

    // Walk the arcs until the scaled random number falls into one of them.
    // Default to the end of the last arc, so that rounding in the running
    // subtraction (e.g. rn == 1) can never leave phi outside every arc.
    Float remaining = rn * total;
    phi = 0;
    if (!curves.empty()) {
        phi = curves.back().y();
    }
    for (const Vector2 &curve : curves) {
        Float len = curve.y() - curve.x();
        if (remaining <= len) {
            phi = curve.x() + remaining;
            break;
        }
        remaining -= len;
    }
    phi = fmod(phi, 2.0 * M_PI);

    // Compute the jacobian.
    Float cp = std::cos(phi);
    Float sp = std::sin(phi);
    Float cp2 = cp * cp;
    Float sp2 = sp * sp;
    jacobian =
        m2 * dm1 * cp2 +
        m1 * dm2 * sp2 +
        dO.dot(m2 * cp * TN + m1 * sp * UN) +
        O.dot(m2 * cp * dTN + m1 * sp * dUN);
    jacobian = std::abs(jacobian);

    // Map the angle to a point in the ellipsoid frame.
    Vector p_circle = angleToPoint(phi);
    Vector p_ellipsoid = circleToEllipsoid(p_circle);

    return p_ellipsoid;
}

// utils
// Weighted dot product u^T diag(diag) v, accumulated in x, y, z order.
Float Ellipse::quadDot(const Vector &u, const Vector &v, const Vector &diag) const {
    return u.x() * v.x() * diag.x()
         + u.y() * v.y() * diag.y()
         + u.z() * v.z() * diag.z();
}

// Polar angle of p's (x, y) components, in [-pi, pi] (std::atan2 range).
Float Ellipse::pointToAngle(const Vector &p) const {
    return std::atan2(p.y(), p.x());
}

// Point on the unit circle in the z = 0 plane at the given angle.
Vector Ellipse::angleToPoint(const Float &angle) const {
    return Vector(std::cos(angle), std::sin(angle), 0.0);
}

// Ellipsoid coordinates -> ellipse frame: translate so O is the origin, then
// express the offset in the (TN, UN, N) frame.
Vector Ellipse::ellipsoidToEllipse(const Vector &p_ellipsoid) const {
    const Vector offset = p_ellipsoid - O;
    return ellipse_frame.toLocal(offset);
}

// Inverse of ellipsoidToEllipse: rotate out of the ellipse frame, then shift by O.
Vector Ellipse::ellipseToEllipsoid(const Vector &p_ellipse) const {
    const Vector rotated = ellipse_frame.toWorld(p_ellipse);
    return rotated + O;
}

// Ellipsoid coordinates -> circle space: go to the ellipse frame, then divide
// the in-plane components by the semi-axes m1, m2 (z is passed through).
Vector Ellipse::ellipsoidToCircle(const Vector &p_ellipsoid) const {
    const Vector local = ellipsoidToEllipse(p_ellipsoid);
    return Vector(local.x() / m1, local.y() / m2, local.z());
}

// Inverse of ellipsoidToCircle: scale x, y by m1, m2 (z unchanged), then map
// from the ellipse frame back to ellipsoid coordinates.
Vector Ellipse::circleToEllipsoid(const Vector &p_circle) const {
    const Vector scaled(p_circle.x() * m1, p_circle.y() * m2, p_circle.z());
    return ellipseToEllipsoid(scaled);
}


// Ellipsoid
void Ellipsoid::init() {
    center = (f1 + f2) * 0.5;

    a = tau * 0.5;
    Float focusDistSqr = (f1 - center).squaredNorm();
    b = std::sqrt(a * a - focusDistSqr);
    c = b;

    Vector u, v, w;
    u = (f2 - f1).normalized();
    coordinateSystem(u, v, w);
    ellipsoid_frame.s = u;
    ellipsoid_frame.t = v;
    ellipsoid_frame.n = w;

    Matrix3x3 S = Vector(a, b, c).asDiagonal();
    Matrix3x3 R;
    R.col(0) = u;
    R.col(1) = v;
    R.col(2) = w;
    Matrix3x3 M = R * S;
    Vector delta(M.row(0).norm(), M.row(1).norm(), M.row(2).norm());
    bounds = AABB(center - delta, center + delta);
}

// Intersects the ellipsoid with the plane of triangle (C1, C2, C3), given in
// world space. Returns a default-constructed Ellipse when the plane misses the
// ellipsoid. Otherwise returns an Ellipse built from the vertices and the
// section centre in ellipsoid coordinates, plus the semi-axes and tau.
Ellipse Ellipsoid::intersectTri(const Vector &C1, const Vector &C2, const Vector &C3) {
    // In sphere space the ellipsoid becomes the unit sphere at the origin.
    Vector sphere_C1 = worldToSphere(C1);
    Vector sphere_C2 = worldToSphere(C2);
    Vector sphere_C3 = worldToSphere(C3);
    Vector sphere_n = (sphere_C2 - sphere_C1).cross(sphere_C3 - sphere_C1).normalized();

    // Signed distance from the sphere centre to the triangle plane. Its sign
    // depends on the triangle's vertex winding.
    Float sphere_dist = sphere_n.dot(sphere_C1);

    // The plane meets the unit sphere iff |distance| <= 1. Comparing the signed
    // value would accept planes far on the negative side, which then produce
    // 1 - OOD < 0 and NaN semi-axes in the Ellipse.
    if (epsExclusiveGreater(std::abs(sphere_dist), 1.0)) return Ellipse();

    // Foot of the perpendicular from the sphere centre (centre of the circular
    // section), scaled into ellipsoid coordinates.
    Vector sphere_O = sphere_dist * sphere_n;
    Vector ellipsoid_O(sphere_O.x() * a, sphere_O.y() * b, sphere_O.z() * c);

    // Convert the vertices to the ellipsoid coordinate system.
    Vector ellipsoid_C1 = worldToEllipsoid(C1);
    Vector ellipsoid_C2 = worldToEllipsoid(C2);
    Vector ellipsoid_C3 = worldToEllipsoid(C3);

    Vector abc(a, b, c);
    Ellipse elp(ellipsoid_C1, ellipsoid_C2, ellipsoid_C3, ellipsoid_O, abc, tau);
    return elp;
}

// utils
// World -> ellipsoid coordinates: translate by -center, then express the
// offset in the ellipsoid frame.
Vector Ellipsoid::worldToEllipsoid(const Vector &p_world) const {
    const Vector offset = p_world - center;
    return ellipsoid_frame.toLocal(offset);
}

// Inverse of worldToEllipsoid: rotate out of the ellipsoid frame, then shift
// by the centre.
Vector Ellipsoid::ellipsoidToWorld(const Vector &p_ellipsoid) const {
    const Vector rotated = ellipsoid_frame.toWorld(p_ellipsoid);
    return rotated + center;
}

// World -> unit-sphere space: go to ellipsoid coordinates, then divide each
// component by the matching semi-axis (a, b, c).
Vector Ellipsoid::worldToSphere(const Vector &p_world) const {
    const Vector local = worldToEllipsoid(p_world);
    return Vector(local.x() / a, local.y() / b, local.z() / c);
}

// Inverse of worldToSphere: scale by the semi-axes (a, b, c), then map the
// ellipsoid coordinates back to world space.
Vector Ellipsoid::sphereToWorld(const Vector &p_sphere) const {
    const Vector stretched(p_sphere.x() * a, p_sphere.y() * b, p_sphere.z() * c);
    return ellipsoidToWorld(stretched);
}