#include "ellipsoid_jacobian_AD.h"
#include <iostream>

// EllipseJacobianAD
// Builds the ellipse in which the plane of triangle (C1, C2, C3) cuts the
// axis-aligned ellipsoid x^2/a^2 + y^2/b^2 + z^2/c^2 = 1, with (a, b, c) = abc.
// It stores the ellipse's AD description in members:
//   - theta: rotation of the principal axes relative to (T, U)
//   - m1, m2: semi-axis lengths, with m1 >= m2
//   - TN, UN: in-plane principal axes
//   - u, v: coordinates of the ellipse point at parameter phi
//
// C1, C2, C3 : triangle vertices, as plain values without derivatives
//              (intersectTri passes ellipsoid-space values).
// abc        : ellipsoid semi-axis lengths, with derivatives.
// tau        : not referenced in this body; any dependence on it presumably
//              arrives through the derivatives carried by abc and O.
// phi        : angular parameter of the point on the ellipse.
// NOTE(review): the ellipse centre `O` is not a parameter. It is presumably a
// member set by the constructor before init runs; confirm in the header.
void EllipseJacobianAD::init(const Vector3 &C1, const Vector3 &C2, const Vector3 &C3,
        const Vector3AD2 &abc, const FloatAD2 &tau, const FloatAD2 &phi) {
    // Orthonormal basis of the triangle plane: T points along C1->C2, N_ is the
    // plane normal, and U = N_ x T completes the in-plane pair. These vectors
    // are constants (no derivatives). N_ is only needed to construct U_.
    Vector3 T_ = (C2 - C1).normalized();
    Vector3 N_ = T_.cross(C3 - C1).normalized();
    Vector3 U_ = N_.cross(T_);
    Vector3AD2 T(T_), U(U_);

    // Diagonal of the ellipsoid's quadratic form: D = (1/a^2, 1/b^2, 1/c^2).
    // TODO: Consider extending simpleAD.h to convert Array to Vector
    Array3AD2 D_ = 1.0 / abc.toArray1D().pow(2);
    Vector3AD2 D(D_[0], D_[1], D_[2]);

    // Entries of the quadratic form restricted to the plane basis (T, U),
    // i.e. the symmetric 2x2 matrix [[TTD, TUD], [TUD, UUD]]. OOD is the form
    // evaluated at the centre O.
    FloatAD2 TTD = quadDotAD(T, T, D);
    FloatAD2 TUD = quadDotAD(T, U, D);
    FloatAD2 UUD = quadDotAD(U, U, D);
    FloatAD2 OOD = quadDotAD(O, O, D);

    // Compute theta according to SM 1.1.
    // Only values are passed to std::atan2, so theta carries no derivatives,
    // and neither do TN / UN below.
    // NOTE(review): the discarded derivative is zero when abc.y() == abc.z()
    // (EllipsoidJacobianAD::init sets c = b). With T, U orthonormal, TUD and
    // UUD - TTD are then both multiples of (1/a^2 - 1/b^2), so theta depends
    // only on that factor's sign. For a general abc this need not hold.
    // TODO: There is no atan2 in simpleAD.h. Should we implement it?
    theta = 0.5 * std::atan2(2.0 * TUD.val, (UUD - TTD).val);

    // Delta = sqrt((TTD - UUD)^2 + 4 TUD^2) is the gap between the two
    // eigenvalues of the 2x2 form. When TTD ~= UUD the (approximately) equal
    // value 2|TUD| is used instead, presumably so the AD sqrt is never
    // differentiated at or near zero.
    // DR1 / DR2 are twice the smaller / larger eigenvalue (Delta >= 0).
    FloatAD2 Delta = almostEqual(TTD.val, UUD.val) ?
        2.0 * TUD.abs() :
        (4.0 * TUD.pow(2) + (TTD - UUD).pow(2)).sqrt();
    FloatAD2 DR1 = TTD + UUD - Delta;
    FloatAD2 DR2 = TTD + UUD + Delta;

    // Compute m1 and m2 according to SM 1.2: m = sqrt((1 - OOD) / lambda).
    // The smaller eigenvalue (DR1) yields the larger semi-axis m1, assuming O
    // is the ellipse centre so the linear term of the form vanishes.
    m1 = (2.0 * (1.0 - OOD) / DR1).sqrt();
    m2 = (2.0 * (1.0 - OOD) / DR2).sqrt();

    // Principal axes of the ellipse: (T, U) rotated within the plane by theta.
    // TN is the direction of the m1 semi-axis and UN that of m2.
    FloatAD2 cosTheta = theta.cos();
    FloatAD2 sinTheta = theta.sin();
    TN = T * cosTheta - U * sinTheta;
    UN = T * sinTheta + U * cosTheta;

    // Point on the ellipse at angle phi, and its coordinates along TN / UN.
    // TN / UN carry no derivatives, so (u, v) are measured in a frame that is
    // fixed with respect to the AD parameters.
    Vector3AD2 P = O + TN * m1 * phi.cos() + UN * m2 * phi.sin();
    u = TN.dot(P);
    v = UN.dot(P);

    /*
    // for debugging
    printf("\n===== AD =====\n");
    printf("theta = %.6f\n", theta.val);
    printf("P = (%.6f, %.6f, %.6f)\n", P.x().val, P.y().val, P.z().val);
    printf("TN = (%.6f, %.6f, %.6f)\n", TN.x().val, TN.y().val, TN.z().val);
    printf("UN = (%.6f, %.6f, %.6f)\n", UN.x().val, UN.y().val, UN.z().val);
    printf("SM 3.1\n");
    printf("dTN = (%.6f, %.6f, %.6f)\n", TN.x().grad(0), TN.y().grad(0), TN.z().grad(0));
    printf("dUN = (%.6f, %.6f, %.6f)\n", UN.x().grad(0), UN.y().grad(0), UN.z().grad(0));
    printf("SM 3.2\n");
    printf("dm1 = %.6f, dm2 = %.6f\n", m1.grad(0), m2.grad(0));
    printf("SM 3.3\n");
    printf("dO = (%.6f, %.6f, %.6f)\n", O.x().grad(0), O.y().grad(0), O.z().grad(0));
    printf("others\n");
    printf("DuDtau = %.6f, DvDtau = %.6f\n", u.grad(0), v.grad(0));
    printf("DuDphi = %.6f, DvDphi = %.6f\n", u.grad(1), v.grad(1));
    */
}

// Returns |det J| for J = d(u, v) / d(tau, phi), the 2x2 Jacobian read from
// the AD gradients of u and v (gradient slot 0 = tau, slot 1 = phi).
Float EllipseJacobianAD::computeJacobian() {
    Float dudtau = u.grad(0), dudphi = u.grad(1);
    Float dvdtau = v.grad(0), dvdphi = v.grad(1);
    return std::abs(dudtau * dvdphi - dudphi * dvdtau);
}

// utils
// Bilinear form with a diagonal matrix: returns u^T diag(diag) v, i.e.
// u.x*v.x*diag.x + u.y*v.y*diag.y + u.z*v.z*diag.z, summed left to right.
FloatAD2 EllipseJacobianAD::quadDotAD(const Vector3AD2 &u, const Vector3AD2 &v, const Vector3AD2 &diag) const {
    FloatAD2 xTerm = u.x() * v.x() * diag.x();
    FloatAD2 yTerm = u.y() * v.y() * diag.y();
    FloatAD2 zTerm = u.z() * v.z() * diag.z();
    return xTerm + yTerm + zTerm;
}

// EllipsoidJacobianAD
void EllipsoidJacobianAD::init() {
    center = (f1 + f2) * 0.5;

    a = tau * 0.5;
    Float focusDistSqr = (f1 - center).squaredNorm();
    b = (a * a - focusDistSqr).sqrt();
    c = b;

    Vector3 u, v, w;
    u = (f2 - f1).normalized();
    coordinateSystem(u, v, w);
    ellipsoid_frame.s = u;
    ellipsoid_frame.t = v;
    ellipsoid_frame.n = w;
}

// Intersects the plane of triangle (C1, C2, C3) with this ellipsoid and
// returns the resulting ellipse parameterisation.
//
// The plane is first taken to unit-sphere space. There, the foot of the
// perpendicular from the origin is the centre of the circular cross-section.
// Scaling that point by (a, b, c) gives the ellipse centre in ellipsoid space,
// because an axis scaling maps the circle's centre to the ellipse's centre.
// Only the values of the ellipsoid-space vertices are forwarded to the ellipse.
EllipseJacobianAD EllipsoidJacobianAD::intersectTri(const Vector3 &C1, const Vector3 &C2, const Vector3 &C3) {
    // Triangle in unit-sphere space, and the unit normal of its plane.
    Vector3AD2 q1 = worldToSphereAD(C1);
    Vector3AD2 q2 = worldToSphereAD(C2);
    Vector3AD2 q3 = worldToSphereAD(C3);
    Vector3AD2 planeNormal = (q2 - q1).cross(q3 - q1).normalized();

    // Closest point of the plane to the origin, which is the circle's centre.
    FloatAD2 planeOffset = planeNormal.dot(q1);
    Vector3AD2 circleCenter = planeOffset * planeNormal;

    // Back to ellipsoid space.
    Vector3AD2 ellipseCenter(circleCenter.x() * a, circleCenter.y() * b, circleCenter.z() * c);

    Vector3AD2 semiAxes(a, b, c);
    return EllipseJacobianAD(worldToEllipsoidAD(C1).val,
                             worldToEllipsoidAD(C2).val,
                             worldToEllipsoidAD(C3).val,
                             ellipseCenter, semiAxes, tau, phi);
}

// utils
// Maps a world-space point into the ellipsoid's local frame. It translates by
// `center`, then expresses the offset in the (s, t, n) basis of
// `ellipsoid_frame`.
//
// The projection uses AD dot products instead of
// ellipsoid_frame.toLocal(offset.val). The values-only call would discard
// every derivative carried by p_world.
// NOTE(review): assumes ellipsoid_frame.toLocal(v) == (v.s, v.t, v.n), the
// usual definition for an orthonormal Frame; confirm against the Frame header.
Vector3AD2 EllipsoidJacobianAD::worldToEllipsoidAD(const Vector3AD2 &p_world) const {
    Vector3AD2 offset = p_world - center;
    Vector3AD2 axisS(ellipsoid_frame.s), axisT(ellipsoid_frame.t), axisN(ellipsoid_frame.n);
    return Vector3AD2(offset.dot(axisS), offset.dot(axisT), offset.dot(axisN));
}

// Maps a point from the ellipsoid's local frame back to world space. It forms
// the combination s*x + t*y + n*z of the frame axes, then translates by
// `center`.
//
// The combination is built with AD arithmetic instead of
// ellipsoid_frame.toWorld(p_ellipsoid.val). This keeps the derivatives of
// p_ellipsoid, e.g. the semi-axis scaling applied by sphereToWorldAD.
// NOTE(review): assumes ellipsoid_frame.toWorld(v) == s*v.x + t*v.y + n*v.z,
// the usual definition for an orthonormal Frame; confirm against the Frame
// header.
Vector3AD2 EllipsoidJacobianAD::ellipsoidToWorldAD(const Vector3AD2 &p_ellipsoid) const {
    Vector3AD2 axisS(ellipsoid_frame.s), axisT(ellipsoid_frame.t), axisN(ellipsoid_frame.n);
    Vector3AD2 p_world = axisS * p_ellipsoid.x() + axisT * p_ellipsoid.y() + axisN * p_ellipsoid.z();
    return p_world + center;
}

// World space -> unit-sphere space. The point is expressed in the ellipsoid
// frame, then each coordinate is divided by the matching semi-axis (a, b, c).
Vector3AD2 EllipsoidJacobianAD::worldToSphereAD(const Vector3AD2 &p_world) const {
    Vector3AD2 q = worldToEllipsoidAD(p_world);
    return Vector3AD2(q.x() / a, q.y() / b, q.z() / c);
}

// Unit-sphere space -> world space. Each coordinate is stretched by its
// semi-axis (a, b, c), then mapped from the ellipsoid frame to world space.
Vector3AD2 EllipsoidJacobianAD::sphereToWorldAD(const Vector3AD2 &p_sphere) const {
    return ellipsoidToWorldAD(Vector3AD2(p_sphere.x() * a, p_sphere.y() * b, p_sphere.z() * c));
}