#include "ellipsoid_test.h"
#include "utils.h"
#include "../bsdf/diffuse.h"
#include "../emitter/area.h"
#include "epc_sampler.h"
#include "ellipsoid_jacobian_AD.h"
#include <assert.h>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>

static bool onEllipsoid(const Vector &p, const Ellipsoid &e) {
    Vector p_local_sphere = e.worldToSphere(p);

    bool flag0 = (abs(pow(p_local_sphere.x(), 2) + pow(p_local_sphere.y(), 2) + pow(p_local_sphere.z(), 2) - 1.0) < Epsilon);
    bool flag1 = (abs((p - e.f1).norm() + (p - e.f2).norm() - e.tau) < Epsilon);

    // inconsistent result
    assert(flag0 == flag1);

    return flag0;
}

// Assume that they are all in the ellipse frame for simpler check
static bool inTri(const Vector &p, const Vector &v1, const Vector &v2, const Vector &v3) {
    Float denom = (v2.y() - v3.y()) * 
        (v1.x() - v3.x()) + (v3.x() - v2.x()) * 
        (v1.y() - v3.y());

    Float tmp1 = p.x() - v3.x();
    Float tmp2 = p.y() - v3.y();
    Float u = ((v2.y() - v3.y()) * tmp1 + 
        (v3.x() - v2.x()) * tmp2) / denom;
    Float v = ((v3.y() - v1.y()) * tmp1 + 
        (v1.x() - v3.x()) * tmp2) / denom;
    Float w = 1.0 - u - v;
    return 0.0 <= u && u <= 1.0 && 
        0.0 <= v && v <= 1.0 && 
        0.0 <= w && w <= 1.0;
}

// Sanity check for onEllipsoid().
//
// Builds the ellipsoid with foci (0,0,0) and (2,0,0) and path length 3. That
// gives semi-major axis a = 1.5, focal half-distance c = 1, and so semi-minor
// axis b = sqrt(a^2 - c^2) = sqrt(5/4). The test then checks that the point
// (1, 0, b), whose distance to each focus is 1.5, is reported as on the
// surface.
//
// The result is reported explicitly rather than via assert(). An
// assert(onEllipsoid(...)) disappears entirely under NDEBUG, and the test
// would then print "Passed!" without checking anything.
static void test1() {
    std::cout << "===== Testing whether a point is on an ellipsoid... =====" << std::endl;

    Vector f1(0, 0, 0);
    Vector n1(1, 0, 0);
    Vector f2(2, 0, 0);
    Vector n2(-1, 0, 0);
    Float path_length = 3.0;
    Ellipsoid e(f1, n1, f2, n2, path_length);

    Vector p(1, 0, std::sqrt(5.0 / 4.0));
    if (onEllipsoid(p, e)) {
        std::cout << "Passed!" << std::endl;
    }
    else {
        printf("failed: p(%f, %f, %f) is not on the ellipsoid\n", p.x(), p.y(), p.z());
        std::cout << "Failed!" << std::endl;
    }
}

static void test2() {
    std::cout << "===== Testing ellipsoid-triangle intersection... =====" << std::endl;

    Vector f1(0, 0, 0);
    Vector n1(1, 0, 0);
    Vector f2(2, 0, 0);
    Vector n2(-1, 0, 0);
    Float path_length = 3.0;
    Ellipsoid e(f1, n1, f2, n2, path_length);

    std::cout << "Set up an ellipsoid." << std::endl;

    Vector C1(0, 1, 0);
    Vector C2(0, 0, 1);
    Vector C3(0, 0, 0);
    Ellipse elp = e.intersectTri(C1, C2, C3);

    std::cout << "Finish computing intersections." << std::endl;

    if (elp.status == 0) {
        std::cout << "No intersection between the ellipsoid and the triangle." << std::endl;
        std::cout << "Can't proceed." << std::endl;
        return;
    }

    // std::cout << "Ellipse:" << std::endl;
    // std::cout << "\ttheta: " << elp.theta << std::endl;
    // std::cout << "\tm1: " << elp.m1 << std::endl;
    // std::cout << "\tm2: " << elp.m2 << std::endl;

    bool passed = true;
    // test if P is on the ellipsoid
    for (int i = 0; i < 100; i++) {
        Float phi = 2 * M_PI / 100.0 * i;
        Vector3 p(cos(phi), sin(phi), 0.0);
        p = elp.circleToEllipsoid(p);
        p = e.ellipsoidToWorld(p);

        if (!onEllipsoid(p, e)) {
            printf("failed: p(%f, %f, %f) is not on the ellipsoid\n", p.x(), p.y(), p.z());
            passed = false;
        }
    }
    if (passed) {
        printf("Passed!\n");
    }
    else {
        printf("Failed!\n");
    }
}

static void test3() {
    std::cout << "===== Testing elliptic curve sampling... =====" << std::endl;

    std::srand((unsigned int)time(0));

    Vector f1(0, 0, 0);
    Vector n1(1, 0, 0);
    Vector f2(2, 0, 0);
    Vector n2(-1, 0, 0);
    f1 += 0.1 * Eigen::Vector3d::Random();
    f2 += 0.1 * Eigen::Vector3d::Random();

    Float path_length = 3.0;
    Ellipsoid e(f1, n1, f2, n2, path_length);

    Vector C1(0, 1, 0);
    Vector C2(0, 0, 1);
    Vector C3(-0.1, 0.2, 0);

    C1 += 0.1 * Eigen::Vector3d::Random();
    C2 += 0.1 * Eigen::Vector3d::Random();
    C3 += 0.1 * Eigen::Vector3d::Random();

    /*
    Vector C1(-10, -10, 1);
    Vector C2(10, -10, 1.1);
    Vector C3(10, 10, 0.9);
    */

    Ellipse elp = e.intersectTri(C1, C2, C3);

    if (!elp.canSample()) {
        std::cout << "No elliptic curve to sample from!" << std::endl;
        std::cout << "Can't proceed." << std::endl;
        return;
    }

    std::cout << "Sampling from " << elp.curves.size() << " curves..." << std::endl;

    Vector C1_ellipse = elp.ellipsoidToEllipse(e.worldToEllipsoid(C1));
    Vector C2_ellipse = elp.ellipsoidToEllipse(e.worldToEllipsoid(C2));
    Vector C3_ellipse = elp.ellipsoidToEllipse(e.worldToEllipsoid(C3));

    bool passed = true;
    for (int i = 0; i < 100; i++) {
        Float rn = (i + 0.5) / 100.0;
        Float phi;
        Float pdf;
        Float jacobian;
        Vector3 p_ellipsoid = elp.sampleEllipticCurve(rn, phi, pdf, jacobian);
        Vector3 p = e.ellipsoidToWorld(p_ellipsoid);
        Vector3 p_ellipse = elp.ellipsoidToEllipse(p_ellipsoid);

        if (!onEllipsoid(p, e)) {
            printf("failed: p(%f, %f, %f) is not on the ellipsoid\n", p.x(), p.y(), p.z());
            passed = false;
        }

        if (!inTri(p_ellipse, C1_ellipse, C2_ellipse, C3_ellipse)) {
            printf("failed: p(%f, %f, %f) is inside the triangle\n", p.x(), p.y(), p.z());
            passed = false;
        }

        // check Jacobian
        FloatAD2 tauAD, phiAD;
        tauAD.val = path_length;
        tauAD.grad(0) = 1;
        tauAD.grad(1) = 0;
        phiAD.val = phi;
        phiAD.grad(0) = 0;
        phiAD.grad(1) = 1;

        EllipsoidJacobianAD elpJac(f1, f2, tauAD, phiAD);
        EllipseJacobianAD eJac = elpJac.intersectTri(C1, C2, C3);
        Float jacobianAD = eJac.computeJacobian();
        if (std::abs(jacobian - jacobianAD) > Epsilon) {
            printf("failed: Jacobian from formula: %.6f != AD: %.6f\n", jacobian, jacobianAD);
            passed = false;
        }
    }
    if (passed) {
        printf("Passed!\n");
    }
    else {
        printf("Failed!\n");
    }
}

static void test4() {
    std::cout << "===== Testing BVH building... =====" << std::endl;
    AreaLight area(0, Spectrum3f(10.0f, 10.0f, 10.0f));
    DiffuseBSDF diffuseBSDF(Spectrum3f(0.5f, 0.5f, 0.5f));
    Camera camera;
    RndSampler sampler(123, 0);

    float cubeVtxPositions[] = {
        -0.5f, -0.5f, -0.5f,
         0.5f, -0.5f, -0.5f,
         0.5f,  0.5f, -0.5f,
        -0.5f,  0.5f, -0.5f,
        -0.5f, -0.5f,  0.5f,
         0.5f, -0.5f,  0.5f,
         0.5f,  0.5f,  0.5f,
        -0.5f,  0.5f,  0.5f
    };

    int cubeVtxIndices[] = {
        0, 2, 1,
        0, 3, 2,
        0, 1, 5,
        0, 5, 4,
        1, 2, 6,
        1, 6, 5,
        4, 5, 6,
        4, 6, 7,
        3, 6, 2,
        3, 7, 6,
        0, 7, 3,
        0, 4, 7
    };

    float planeVtxPositions[] = {
        -100.0f, -100.0f, -2.0f,
         100.0f, -100.0f, -2.0f,
         100.0f,  100.0f, -2.0f,
        -100.0f,  100.0f, -2.0f, 
    };

    int planeVtxIndices[] = {
        1, 2, 0,
        2, 3, 0,
    };

    float lightVtxPositions[] = {
        -1.0f, -1.0f, 5.0f,
         1.0f, -1.0f, 5.0f,
         1.0f,  1.0f, 5.0f,
        -1.0f,  1.0f, 5.0f, 
    };

    int lightVtxIndices[] = {
        0, 2, 1,
        0, 3, 2,
    }; 

    Shape cube(cubeVtxPositions, cubeVtxIndices, nullptr, nullptr, 8, 12, -1, 0, -1, -1);
    Shape plane(planeVtxPositions, planeVtxIndices, nullptr, nullptr, 4, 2, -1, 0, -1, -1);
    Shape light(lightVtxPositions, lightVtxIndices, nullptr, nullptr, 4, 2, 0, 0, -1, -1);

    area.shape_id = 2;

    Scene scene(camera,
                std::vector<const Shape*>{&cube, &plane, &light},
                std::vector<const BSDF*>{&diffuseBSDF},
                std::vector<const Emitter*>{&area},
                std::vector<const PhaseFunction*>(),
                std::vector<const Medium*>());
}

static void test5() {
    std::cout << "===== Testing camera with multiple bins... =====" << std::endl;
    Camera camera;
    camera.setPIF(2, 1200.0, 50.0, 8, 50.0);
    const auto &pif = camera.pif;
    bool flag = true;

    Float path_length = pif->tau.val - 3 * pif->deltaTau.val;
    while (path_length < pif->tau.val + 3 * pif->deltaTau.val) {
        Vector2i bin_range = pif->getBinIndexRange(path_length);

        for (int i = 0; i < bin_range[0]; i++) {
            Float pif_kernel = pif->eval(path_length, i);
            if (pif_kernel > 0.005)
                flag = false;
        }

        for (int i = bin_range[1] + 1; i < pif->num_bins; i++) {
            Float pif_kernel = pif->eval(path_length, i);
            if (pif_kernel > 0.005)
                flag = false;
        }

        if (!flag) {
            std::cout << "Failed test5: camera with multiple bins!" << std::endl;
            return;
        }

        path_length += pif->stepTau * pif->num_bins / 5000.0;
    }

    std::cout << "Passed test5!" << std::endl;
}

void ellipsoid_test() {
    test1();
    test2();
    test3();
    test4();
    test5();
}