#include "ellipsoidAD.h"
#include <iostream>

// Ellipse
// Builds the ellipse cut out by the plane of triangle (C1, C2, C3) and records
// which arcs of that ellipse lie inside the triangle.
//
// Parameters:
//   C1, C2, C3  triangle vertices (EllipsoidAD::intersectTri passes them after
//               converting with worldToEllipsoidAD).
//   abc         three per-axis lengths; only 1/abc^2 and 1/abc^4 are used here.
//   tau         only used as a scale factor in dO.
//
// Reads the member O, which is not assigned in this function.
//
// Writes: N, theta, m1, m2, TN, UN, ellipse_frame, curves, status and, when
// status == 3, dTN, dUN, dO, dm1, dm2.
//
// status on return:
//   1  no edge intersection was found and the unit circle is not inside the
//      triangle (returns early; curves and the d* members are untouched),
//   2  intersections were found but no arc midpoint lies inside the triangle,
//   3  at least one arc was stored in curves and the Jacobian terms were computed.
void EllipseAD::init(const VectorAD &C1, const VectorAD &C2, const VectorAD &C3,
    const VectorAD &abc, const FloatAD &tau) {

    // Compute an orthonormal basis of the triangle plane
    VectorAD T = (C2 - C1).normalized();
    N = T.cross(C3 - C1).normalized();
    VectorAD U = N.cross(T);

    // D holds 1 / abc^2 per component; it is the diagonal passed to quadDotAD below.
    // TODO: Consider extending simpleAD.h to convert Array to Vector
    Array3AD D_ = 1.0 / abc.toArray1D().pow(2);
    VectorAD D(D_[0], D_[1], D_[2]);

    // Diagonal-weighted dot products (sum of u_i * v_i * D_i) of the basis vectors and of O.
    FloatAD TTD = quadDotAD(T, T, D);
    FloatAD TUD = quadDotAD(T, U, D);
    FloatAD UUD = quadDotAD(U, U, D);
    FloatAD OOD = quadDotAD(O, O, D);

    // Compute theta according to SM 1.1.
    theta = 0.5 * (2.0 * TUD).atan2(UUD - TTD);

    // Compute some useful intermediate values.
    // Delta = sqrt(4 TUD^2 + (TTD - UUD)^2); when TTD and UUD are almost equal the
    // second term is dropped and the closed form 2 |TUD| is used instead.
    FloatAD Delta = almostEqual(TTD.val, UUD.val) ?
        2.0 * TUD.abs() :
        (4.0 * TUD.pow(2) + (TTD - UUD).pow(2)).sqrt();
    FloatAD DR1 = TTD + UUD - Delta;
    FloatAD DR2 = TTD + UUD + Delta;

    // Compute m1 and m2 according to SM 1.2.
    // Both share the numerator 2 (1 - OOD); nothing here guards against OOD > 1.
    m1 = (2.0 * (1.0 - OOD) / DR1).sqrt();
    m2 = (2.0 * (1.0 - OOD) / DR2).sqrt();

    // Rotate (T, U) by theta within the plane to obtain the ellipse frame axes.
    FloatAD cosTheta = theta.cos();
    FloatAD sinTheta = theta.sin();
    TN = T * cosTheta - U * sinTheta;
    UN = T * sinTheta + U * cosTheta;

    ellipse_frame.s = TN;
    ellipse_frame.t = UN;
    ellipse_frame.n = N;

    // Triangle vertices mapped by ellipsoidToCircleAD (ellipse frame, x / m1, y / m2).
    // From here on the ellipse is treated as the unit circle in the x/y plane.
    VectorAD circle_c[3];
    circle_c[0] = ellipsoidToCircleAD(C1);
    circle_c[1] = ellipsoidToCircleAD(C2);
    circle_c[2] = ellipsoidToCircleAD(C3);
    // Common denominator of the 2D barycentric coordinates used by isInTri.
    // NOTE(review): this is zero for a degenerate (collinear) mapped triangle; not checked here.
    FloatAD denom = (circle_c[1].y() - circle_c[2].y()) *
        (circle_c[0].x() - circle_c[2].x()) + (circle_c[2].x() - circle_c[1].x()) *
        (circle_c[0].y() - circle_c[2].y());

    // Point-in-triangle test on the x/y components of p: computes barycentric
    // coordinates (u, v, w) with respect to circle_c and requires each of them to
    // pass the inclusive [0, 1] comparisons (comparison helpers are defined elsewhere).
    auto isInTri = [&](const VectorAD &p) -> bool {
        FloatAD tmp1 = p.x() - circle_c[2].x();
        FloatAD tmp2 = p.y() - circle_c[2].y();
        FloatAD u = ((circle_c[1].y() - circle_c[2].y()) * tmp1 +
            (circle_c[2].x() - circle_c[1].x()) * tmp2) / denom;
        FloatAD v = ((circle_c[2].y() - circle_c[0].y()) * tmp1 +
            (circle_c[0].x() - circle_c[2].x()) * tmp2) / denom;
        FloatAD w = 1.0 - u - v;
        return epsInclusiveGreater(u.val, 0) && epsInclusiveLesser(u.val, 1.0) &&
            epsInclusiveGreater(v.val, 0) && epsInclusiveLesser(v.val, 1.0) &&
            epsInclusiveGreater(w.val, 0) && epsInclusiveLesser(w.val, 1.0);
    };

    // Angles (on the unit circle) at which triangle edges cross the circle.
    std::vector<FloatAD> phis;

    // Find intersections with each side of the triangle
    for (size_t i = 0; i < 3; i++) {
        // Edge runs between vertex i and the next vertex j (wrapping 2 -> 0).
        size_t j = i == 2 ? 0 : i + 1;

        FloatAD x1 = circle_c[i].x();
        FloatAD y1 = circle_c[i].y();
        FloatAD x2 = circle_c[j].x();
        FloatAD y2 = circle_c[j].y();

        // Points on the edge are p(alpha) = alpha * c_i + (1 - alpha) * c_j.
        // |p(alpha)|^2 = 1 expands to a * alpha^2 + b * alpha + c = 0 with:
        FloatAD a = (x1 - x2).pow(2) + (y1 - y2).pow(2);
        FloatAD b = 2.0 * (x1 * x2 - x2 * x2 + y1 * y2 - y2 * y2);
        FloatAD c = x2 * x2 + y2 * y2 - 1.0;
        FloatAD disc = b * b - 4.0 * a * c;

        // No intersection
        if (epsExclusiveLesser(disc, 0.0)) continue;

        // Roots are accepted only for alpha in (0, 1]: alpha = 0 (vertex j) is rejected,
        // alpha = 1 (vertex i) is kept.
        // NOTE(review): presumably this avoids recording a hit at a shared vertex twice — confirm.

        // One intersection
        if (almostEqual(disc, 0.0)) {
            FloatAD alpha = -0.5 * b / a;
            if (!almostEqual(alpha, 0.0) && epsExclusiveGreater(alpha, 0.0)
                && epsInclusiveLesser(alpha, 1.0)) {
                VectorAD itst = alpha * circle_c[i] + (1.0 - alpha) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
        }
        // Two intersection
        else {
            // disc is reused to hold its own square root from here on.
            disc = disc.sqrt();
            FloatAD alpha1 = (-b + disc) * 0.5 / a;
            FloatAD alpha2 = (-b - disc) * 0.5 / a;
            if (!almostEqual(alpha1, 0.0) && epsExclusiveGreater(alpha1, 0.0)
                && epsInclusiveLesser(alpha1, 1.0)) {
                VectorAD itst = alpha1 * circle_c[i] + (1.0 - alpha1) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
            if (!almostEqual(alpha2, 0.0) && epsExclusiveGreater(alpha2, 0.0)
                && epsInclusiveLesser(alpha2, 1.0)) {
                VectorAD itst = alpha2 * circle_c[i] + (1.0 - alpha2) * circle_c[j];
                phis.push_back(pointToAngleAD(itst));
            }
        }
    }

    if (phis.empty()) {
        // No edge crosses the circle. Three probe points (the circle centre and the
        // points at angle 0 and pi/2) decide whether the whole circle is inside.
        // The circle is completely inside the triangle
        if (isInTri(VectorAD(0, 0, 0)) && isInTri(VectorAD(1, 0, 0))
            && isInTri(VectorAD(0, 1, 0))) {
            // A single angle makes the loop below produce one arc [0, 2pi].
            phis.push_back(0.0);
        }
        else {
            status = 1;
            return;
        }
    }

    // Map phi into [0, 2pi]
    // Only the value is adjusted; the derivative part of each FloatAD is left as is.
    for (FloatAD &phi : phis) {
        // TODO: convert this into its AD version
        phi.val = fmod(phi.val, 2.0 * M_PI);
        if (epsExclusiveLesser(phi, 0.0)) phi.val += 2.0 * M_PI;
    }

    // Sort all the angles
    std::sort(phis.begin(), phis.end());

    // Walk consecutive angle pairs (the last one wraps around to the first plus 2pi)
    // and keep the arcs whose midpoint lies inside the triangle.
    for (size_t i = 0; i < phis.size(); i++) {
        size_t j;
        FloatAD start = phis[i];
        FloatAD end;
        if (i == phis.size() - 1) {
            j = 0;
            end = phis[j] + 2.0 * M_PI;
        }
        else {
            j = i + 1;
            end = phis[j];
        }

        VectorAD mid = angleToPointAD((start + end) * 0.5); // midpoint of the curve

        // The curve is inside the triangle
        if (isInTri(mid)) {
            // Each stored curve is the pair (start angle, end angle); end may exceed 2pi.
            curves.push_back(Vector2AD(start, end));
        }
    }

    if (curves.empty()) {
        status = 2;
    }
    else {
        status = 3;

        // Compute useful values for Jacobian computation.
        // tmp = 1 / abc^4 per component and E = -abc.x() * tmp.
        Array3AD tmp_ = 1.0 / abc.toArray1D().pow(4);
        Array3AD E_ = -abc.x() * tmp_;
        VectorAD tmp(tmp_[0], tmp_[1], tmp_[2]);
        VectorAD E(E_[0], E_[1], E_[2]);

        // Same weighted products as TTD/TUD/UUD/OOD above, with E in place of D.
        FloatAD TTE = quadDotAD(T, T, E);
        FloatAD TUE = quadDotAD(T, U, E);
        FloatAD UUE = quadDotAD(U, U, E);
        FloatAD OOE = quadDotAD(O, O, E);

        // Compute dTN and dUN according to SM 3.1.
        FloatAD dDelta;
        if (epsExclusiveGreater(Delta.val, 0.0)) {
            dDelta = (4.0 * TUD * TUE + (TTE - UUE) * (TTD - UUD)) / Delta;
            FloatAD Delta2 = Delta * Delta;

            FloatAD nsnthetadtheta = (Delta * (UUE - TTE) - dDelta * (UUD - TTD)) / Delta2;
            FloatAD cnthetadtheta = 2.0 * (Delta * TUE - dDelta * TUD) / Delta2;

            dTN = T * nsnthetadtheta - U * cnthetadtheta;
            dUN = T * cnthetadtheta + U * nsnthetadtheta;
        }
        else {
            // Delta is not greater than zero (up to eps): the divisions above are
            // skipped and the three derivative terms are zeroed instead.
            dDelta.zero();
            dTN.zero();
            dUN.zero();
        }

        // Compute dO according to SM 3.3.
        Array3AD Tabc(T.toArray1D() / abc.toArray1D().pow(2));
        Array3AD Uabc(U.toArray1D() / abc.toArray1D().pow(2));

        // Assemble the 3x3 AD matrix with rows (T / abc^2, U / abc^2, N): values first,
        // then one derivative matrix per tracked derivative (nder of them).
        Matrix3x3 dOM_val;
        dOM_val.row(0) << Tabc(0).val, Tabc(1).val, Tabc(2).val;
        dOM_val.row(1) << Uabc(0).val, Uabc(1).val, Uabc(2).val;
        dOM_val.row(2) << N(0).val, N(1).val, N(2).val;
        std::array<Matrix3x3, nder> dOM_der;
        for (size_t i = 0; i < nder; i++) {
            dOM_der[i].row(0) << Tabc.grad(i)(0), Tabc.grad(i)(1), Tabc.grad(i)(2);
            dOM_der[i].row(1) << Uabc.grad(i)(0), Uabc.grad(i)(1), Uabc.grad(i)(2);
            dOM_der[i].row(2) << N.grad(i)(0), N.grad(i)(1), N.grad(i)(2);
        }
        Matrix3x3AD dOMinv = Matrix3x3AD(dOM_val, dOM_der).inverse();

        // Right-hand side uses the 1 / abc^4 weights; its third component is zero.
        VectorAD dOV(quadDotAD(T, O, tmp), quadDotAD(U, O, tmp), 0.0);
        dO = 0.5 * tau * dOMinv * dOV;

        // Compute dm1 and dm2 according to SM 3.2.
        // NR is the numerator shared by m1^2 and m2^2 (see above); DR1/DR2 are the denominators.
        FloatAD NR = 2.0 * (1.0 - OOD);
        FloatAD OdOD = quadDotAD(O, dO, D);
        FloatAD dNR = -2.0 * OOE - 4.0 * OdOD;
        FloatAD dDR1 = TTE + UUE - dDelta;
        FloatAD dDR2 = TTE + UUE + dDelta;
        dm1 = (DR1 * dNR - dDR1 * NR) / (2.0 * DR1.pow(2) * m1);
        dm2 = (DR2 * dNR - dDR2 * NR) / (2.0 * DR2.pow(2) * m2);
    }
}

// Picks an angle uniformly (by angular length) over the arcs stored in `curves`
// and returns the corresponding point via circleToEllipsoidAD.
//
// Parameters:
//   rn        random number used to select the angle; it is scaled by the total
//             arc length, so values in [0, 1] cover all arcs.
//   phi       [out] sampled angle, reduced with fmod(., 2pi).
//   pdf       [out] 1 / (total angular length of all arcs).
//   jacobian  [out] absolute value of the Jacobian term assembled from
//             m1, m2, dm1, dm2, O, dO, TN, UN, dTN and dUN at angle phi.
//
// Precondition: `curves` must be non-empty (init() finished with status 3).
// With no arcs the total length is zero and pdf becomes a division by zero.
VectorAD EllipseAD::sampleEllipticCurve(const Float &rn, Float &phi, Float &pdf, FloatAD &jacobian) {
    // Total angular length of all arcs. The running sum is advanced with the
    // expression `sum + end - start` so the prefix sums recomputed in the
    // selection loop below are bit-identical to the ones summed here.
    Float total_len = 0.0;
    for (const Vector2AD& curve : curves) {
        total_len = total_len + curve.y().val - curve.x().val;
    }

    // Compute the pdf for uniform sampling.
    pdf = 1.0 / total_len;

    // Find the sampled angle: walk the arcs, consuming their lengths until the
    // remaining offset falls inside one of them.
    Float scaled_rn = rn * total_len;
    phi = 0;

    Float prefix = 0.0;
    const size_t num_curves = curves.size();
    for (size_t i = 0; i < num_curves; i++) {
        const Float next_prefix = prefix + curves[i].y().val - curves[i].x().val;
        double len_interval = next_prefix - prefix;

        if (scaled_rn <= len_interval) {
            phi = curves[i].x().val + scaled_rn;
            break;
        }
        if (i + 1 == num_curves) {
            // Rounding in the prefix sums can leave the offset marginally past
            // the final arc. Clamp to the end of that arc instead of falling
            // through with phi = 0, which need not lie on any arc.
            phi = curves[i].x().val + len_interval;
            break;
        }
        scaled_rn -= len_interval;
        prefix = next_prefix;
    }
    // Arc end angles may exceed 2pi (wrap-around arc), so reduce the result.
    phi = fmod(phi, 2.0 * M_PI);

    // Compute the jacobian.
    Float cp = std::cos(phi);
    Float sp = std::sin(phi);
    Float cp2 = cp * cp;
    Float sp2 = sp * sp;
    jacobian =
        m2 * dm1 * cp2 +
        m1 * dm2 * sp2 +
        dO.dot(m2 * cp * TN + m1 * sp * UN) +
        O.dot(m2 * cp * dTN + m1 * sp * dUN);
    jacobian = jacobian.abs();

    // Map the angle to a point on the unit circle, then through
    // circleToEllipsoidAD (scale by m1/m2, ellipse frame to world, add O).
    VectorAD p_circle = angleToPointAD(phi);
    VectorAD p_ellipsoid = circleToEllipsoidAD(p_circle);

    return p_ellipsoid;
}

// utils
// Diagonal-weighted dot product: sum over the three components of u_i * v_i * diag_i.
FloatAD EllipseAD::quadDotAD(const VectorAD &u, const VectorAD &v, const VectorAD &diag) const {
    FloatAD weighted_sum = u.x() * v.x() * diag.x();
    weighted_sum = weighted_sum + u.y() * v.y() * diag.y();
    weighted_sum = weighted_sum + u.z() * v.z() * diag.z();
    return weighted_sum;
}

// Polar angle of p's x/y components, i.e. atan2(p.y, p.x); p.z is ignored.
FloatAD EllipseAD::pointToAngleAD(const VectorAD &p) const {
    return FloatAD(p.y().atan2(p.x()));
}

// Point on the unit circle in the x/y plane at the given angle: (cos, sin, 0).
VectorAD EllipseAD::angleToPointAD(const FloatAD &angle) const {
    return VectorAD(angle.cos(), angle.sin(), 0.0);
}

// Expresses a point relative to O and then in the local coordinates of ellipse_frame.
VectorAD EllipseAD::ellipsoidToEllipseAD(const VectorAD &p_ellipsoid) const {
    VectorAD offset_from_O = p_ellipsoid - O;
    return ellipse_frame.toLocal(offset_from_O);
}

// Inverse of ellipsoidToEllipseAD: ellipse_frame local coordinates back out, then shifted by O.
VectorAD EllipseAD::ellipseToEllipsoidAD(const VectorAD &p_ellipse) const {
    VectorAD unshifted = ellipse_frame.toWorld(p_ellipse);
    return unshifted + O;
}

// Converts a point to ellipse-frame coordinates and divides x by m1 and y by m2;
// z is passed through unchanged.
VectorAD EllipseAD::ellipsoidToCircleAD(const VectorAD &p_ellipsoid) const {
    VectorAD in_ellipse = ellipsoidToEllipseAD(p_ellipsoid);
    return VectorAD(in_ellipse.x() / m1, in_ellipse.y() / m2, in_ellipse.z());
}

// Inverse of ellipsoidToCircleAD: multiplies x by m1 and y by m2 (z unchanged),
// then maps the result with ellipseToEllipsoidAD.
VectorAD EllipseAD::circleToEllipsoidAD(const VectorAD &p_circle) const {
    VectorAD stretched(p_circle.x() * m1, p_circle.y() * m2, p_circle.z());
    return ellipseToEllipsoidAD(stretched);
}


// Ellipsoid
void EllipsoidAD::init() {
    center = (f1 + f2) * 0.5;

    a = tau * 0.5;
    FloatAD focusDistSqr = (f1 - center).squaredNorm();
    b = (a * a - focusDistSqr).sqrt();
    c = b;

    VectorAD u, v, w;
    u = (f2 - f1).normalized();
    coordinateSystemAD(u, v, w);
    ellipsoid_frame.s = u;
    ellipsoid_frame.t = v;
    ellipsoid_frame.n = w;

    Matrix3x3 S = Vector(a.val, b.val, c.val).asDiagonal();
    Matrix3x3 R;
    R.col(0) = u.val;
    R.col(1) = v.val;
    R.col(2) = w.val;
    Matrix3x3 M = R * S;
    Vector delta(M.row(0).norm(), M.row(1).norm(), M.row(2).norm());
    bounds = AABB(center.val - delta, center.val + delta);
}

// Intersects the plane of triangle (C1, C2, C3) with this ellipsoid and returns
// the resulting EllipseAD.
//
// Parameters:
//   C1, C2, C3  triangle vertices; they are converted here with worldToSphereAD
//               and worldToEllipsoidAD.
//
// Returns a default-constructed EllipseAD when the triangle plane misses the
// ellipsoid; otherwise an EllipseAD built from the vertices in ellipsoid
// coordinates, the foot point of the plane, (a, b, c) and tau.
EllipseAD EllipsoidAD::intersectTri(const VectorAD &C1, const VectorAD &C2, const VectorAD &C3) {
    // Work in the space where the ellipsoid is the unit sphere.
    VectorAD sphere_C1 = worldToSphereAD(C1);
    VectorAD sphere_C2 = worldToSphereAD(C2);
    VectorAD sphere_C3 = worldToSphereAD(C3);
    VectorAD sphere_n = (sphere_C2 - sphere_C1).cross(sphere_C3 - sphere_C1).normalized();

    // Signed distance from the center of the sphere to the triangle plane.
    // Its sign follows the direction of sphere_n, i.e. the vertex order.
    FloatAD sphere_dist = sphere_n.dot(sphere_C1);

    // Since the sphere is a unit sphere, the plane intersects it only if the
    // distance magnitude is <= 1. The magnitude must be tested: a plane at a
    // signed distance below -1 misses the sphere just as one above +1 does.
    FloatAD sphere_abs_dist = sphere_dist.abs();
    if (epsExclusiveGreater(sphere_abs_dist, 1.0)) return EllipseAD();

    // Foot of the perpendicular from the sphere center onto the plane; the
    // signed distance is the right factor here.
    VectorAD sphere_O = sphere_dist * sphere_n;

    // Convert points to the ellipsoid coordinate system.
    VectorAD ellipsoid_O(sphere_O.x() * a, sphere_O.y() * b, sphere_O.z() * c);

    VectorAD ellipsoid_C1 = worldToEllipsoidAD(C1);
    VectorAD ellipsoid_C2 = worldToEllipsoidAD(C2);
    VectorAD ellipsoid_C3 = worldToEllipsoidAD(C3);

    VectorAD abc(a, b, c);
    EllipseAD elp(ellipsoid_C1, ellipsoid_C2, ellipsoid_C3, ellipsoid_O, abc, tau);
    return elp;
}

// utils
// Expresses a world-space point relative to center, in the local coordinates of ellipsoid_frame.
VectorAD EllipsoidAD::worldToEllipsoidAD(const VectorAD &p_world) const {
    VectorAD from_center = p_world - center;
    return ellipsoid_frame.toLocal(from_center);
}

// Inverse of worldToEllipsoidAD: ellipsoid_frame local coordinates back to world, then shifted by center.
VectorAD EllipsoidAD::ellipsoidToWorldAD(const VectorAD &p_ellipsoid) const {
    VectorAD result = ellipsoid_frame.toWorld(p_ellipsoid);
    result = result + center;
    return result;
}

// Converts a world-space point to ellipsoid coordinates and divides each
// component by the matching length (x / a, y / b, z / c).
VectorAD EllipsoidAD::worldToSphereAD(const VectorAD &p_world) const {
    VectorAD local = worldToEllipsoidAD(p_world);
    return VectorAD(local.x() / a, local.y() / b, local.z() / c);
}

// Inverse of worldToSphereAD: scales the components by (a, b, c) and maps the
// result back to world space with ellipsoidToWorldAD.
VectorAD EllipsoidAD::sphereToWorldAD(const VectorAD &p_sphere) const {
    VectorAD scaled(p_sphere.x() * a, p_sphere.y() * b, p_sphere.z() * c);
    VectorAD p_world = ellipsoidToWorldAD(scaled);
    return p_world;
}
