#include "tof_test.h"
#include "utils.h"
#include "stats.h"
#include "../bsdf/diffuse.h"
#include "../emitter/area.h"
#include "epc_sampler.h"
#include "epc_sampler_AD.h"
#include "sphere_sampler.h"
#include "sphere_sampler_AD.h"
#include <iostream>
#include <assert.h>

static void test1() {
    printf("===== Testing v_{tau}(p_0) =====\n");
    std::srand((unsigned int)time(0));

    AreaLight area(0, Spectrum3f(10.0f, 10.0f, 10.0f));
    DiffuseBSDF diffuseBSDF(Spectrum3f(0.5f, 0.5f, 0.5f));
    Camera camera;
    RndSampler sampler(1234, 0);

    float planeVtxPositions[] = {
        -100.f, 1.f, -100.f,
        100.f, 1.f, -100.f,
        100.f, 1.f, 100.f
    };

    int planeVtxIndices[] = {
        0, 1, 2,
    };

    float lightVtxPositions[] = {
        -100.f, -100.f, -100.f,
        100.f, -100.f, -100.f,
        100.f, -100.f, 100.f
    };

    int lightVtxIndices[] = {
        0, 2, 1,
    };

    Shape plane(planeVtxPositions, planeVtxIndices, nullptr, nullptr, 3, 1, 0, 0, -1, -1);
    Shape light(lightVtxPositions, lightVtxIndices, nullptr, nullptr, 3, 1, -1, 0, -1, -1);

    Eigen::Matrix<Float, 6, 3> dx;
    for (int i = 0; i < 6; i++)
        for (int j = 0; j < 3; j++)
            dx(i, j) = 0.0;
    dx(1, 0) = 1.0;
    dx(1, 1) = 1.0;
    dx(1, 2) = 1.0;
    //dx.block(0, 0, 3, 1) = Vector(1, 0, 0); //Eigen::Vector3d::Random();
    dx.block(0, 0, 3, 1) = Eigen::Vector3d::Random();
    //dx.block(0, 1, 3, 1) = dx.block(0, 0, 3, 1);
    //dx.block(0, 2, 3, 1) = dx.block(0, 0, 3, 1);
    plane.initVelocities(dx);

    VectorAD v0 = plane.getVertexAD(0);
    VectorAD v1 = plane.getVertexAD(1);
    VectorAD v2 = plane.getVertexAD(2);
    printf("velocity at v0 = (%.3f, %.3f, %.3f)\n", v0.grad(0)[0], v0.grad(0)[1], v0.grad(0)[2]);
    printf("velocity at v1 = (%.3f, %.3f, %.3f)\n", v1.grad(0)[0], v1.grad(0)[1], v1.grad(0)[2]);
    printf("velocity at v2 = (%.3f, %.3f, %.3f)\n", v2.grad(0)[0], v2.grad(0)[1], v2.grad(0)[2]);

    area.shape_id = 0;

    Scene scene(camera,
        std::vector<const Shape*>{&plane, &light},
        std::vector<const BSDF*>{&diffuseBSDF},
        std::vector<const Emitter*>{&area},
        std::vector<const PhaseFunction*>(),
        std::vector<const Medium*>());

    Vector x0(50, 0, -50);
    Vector n0(0, 1, 0);
    Float tau = 3.0;
    SphereSamplerAD sphere_sampler(x0, n0, tau, scene);

    int cnt = 0;
    printf("Start: Test d(path_length)/d(tau)\n");
    for (int i = 0; i < 10000; i++) {
        VectorAD x0;
        x0.val = 0.1 * Eigen::Vector3d::Random();

        FloatAD path_length = 3.0 + sampler.next1D() * 1.0;
        path_length.grad(1) = 1.0;

        SphereSamplerAD epc_sampler_AD(x0, n0, path_length, scene);

        Vector2 rnd2D = sampler.next2D();

        IntersectionAD x_its;
        FloatAD x_j;
        Float x_pdf;
        if (!epc_sampler_AD.sample(rnd2D, x_its, x_j, x_pdf)) continue;

        FloatAD res = (x_its.p - x0).norm();
        if (std::abs(res.grad(1) - 1.0) > Epsilon) {
            printf("Failed! d(path_length)/d(tau) = %.6f\n", res.grad(1));
        }
        else {
            cnt++;
        }

    }
    printf("Passed: Test d(path_length)/d(tau) with %d examples\n", cnt);
    
    bool flag = true;
    for (int test_i = 0; test_i < 10; test_i++) {
        printf("=== Case %d ===\n", test_i);

        //PositionSamplingRecord pRec;
        //plane.samplePosition(sampler.next2D(), pRec);
        IntersectionAD its;
        FloatAD jacobian;
        Float pdf;
        if (!sphere_sampler.sample(sampler.next2D(), its, jacobian, pdf)) continue;
        
        //printf("%.3f, %.3f, %.3f\n", its.p[0], its.p[1], its.p[2]);
        //printf("%d\n", itsAD.isValid());

        printf("u: %.3f; %.3f, v: %.3f; %.3f\n", its.barycentric.x().val, its.barycentric.x().grad(0), 
            its.barycentric.y().val, its.barycentric.y().grad(0));

        // Approach 1
        VectorAD x1, n1;
        FloatAD J;
        scene.getPoint(its.toIntersection(), x1, n1, J);
        Float v_path_length = (x1 - x0).norm().grad(0);

        printf("x1 = (%.3f, %.3f, %.3f); (%.3f, %.3f, %.3f)\n", x1.val[0], x1.val[1], x1.val[2],
            x1.grad(0)[0], x1.grad(0)[1], x1.grad(0)[2]);

        // Approach 2
        FloatAD proj_dist = -n1.dot(x1 - x0);
        Vector circle_center = x0 - n1.val * proj_dist.val;
        Float circle_radius = (x1.val - circle_center).norm();
        //printf("circle_center = (%.3f, %.3f, %.3f)\n", circle_center[0], circle_center[1], circle_center[2]);
        Vector p1_norm = (x1.val - circle_center).normalized();

        RayAD ray(x0, (x1 - x0).normalized());
        
        VectorAD y = its.p;
        printf("y = (%.3f, %.3f, %.3f); (%.3f, %.3f, %.3f)\n", y.val[0], y.val[1], y.val[2],
            y.grad(0)[0], y.grad(0)[1], y.grad(0)[2]);

        const Vector3i &ind = plane.indices[0];
        const VectorAD &v0 = plane.vertices[ind[0]], &v1 = plane.vertices[ind[1]], &v2 = plane.vertices[ind[2]];
        FloatAD u = 0.5 * (y - v0).cross(v2 - v0).norm() / plane.getAreaAD(0);
        FloatAD v = 0.5 * (y - v0).cross(v1 - v0).norm() / plane.getAreaAD(0);
        printf("u: %.3f; %.3f, v: %.3f; %.3f\n", u.val, u.grad(0), v.val, v.grad(0));
        VectorAD p1 = (1.0f - u - v)*VectorAD(v0.val) + u * VectorAD(v1.val) + v * VectorAD(v2.val);

        printf("p1 = (%.3f, %.3f, %.3f); (%.3f, %.3f, %.3f)\n", p1.val[0], p1.val[1], p1.val[2],
            p1.grad(0)[0], p1.grad(0)[1], p1.grad(0)[2]);
        printf("p1_norm = (%.3f, %.3f, %.3f)\n", p1_norm[0], p1_norm[1], p1_norm[2]);
        printf("pdf = %.6f\n", pdf);
        
        Float v_tau = p1.grad(0).dot(p1_norm);

        printf("%.6f, %.6f\n", v_path_length * jacobian.val, v_tau * circle_radius);
        if (std::abs(v_path_length * jacobian.val + v_tau * circle_radius) > Epsilon) flag = false;
    }

    if (flag)
        printf("Test passed!\n");
    else
        printf("Test failed...\n");
}

// Integrated check of the tau-boundary velocity over the whole tau-curve.
//
// Three Monte Carlo estimates are printed:
//   res0 "Gaussian approx" - tau is perturbed by a normal distribution
//                            (std. dev. deltaTau) and the path-length
//                            derivative is weighted by normalPdfAD.
//   res1 "Diff. delta"     - v_path_length * jacobian / pdf at points drawn
//                            by SphereSamplerAD at the exact tau.
//   res2 "Boundary"        - v_tau * circle_radius / pdf at points drawn the
//                            same way.
// The test passes when res1 and res2 agree within their combined confidence
// intervals.
static void test2() {
    printf("===== Testing int v_{tau}(p_0) =====\n");
    // Seeds std::rand, which Eigen's Random() draws from.
    std::srand((unsigned int)time(0));

    AreaLight area(0, Spectrum3f(10.0f, 10.0f, 10.0f));
    DiffuseBSDF diffuseBSDF(Spectrum3f(0.5f, 0.5f, 0.5f));
    Camera camera;
    RndSampler sampler(1234, 0);

    float planeVtxPositions[] = {
        -100.f, 1.f, -100.f,
        100.f, 1.f, -100.f,
        100.f, 1.f, 100.f
    };

    int planeVtxIndices[] = {
        0, 1, 2,
    };

    float lightVtxPositions[] = {
        -100.f, -100.f, -100.f,
        100.f, -100.f, -100.f,
        100.f, -100.f, 100.f
    };

    int lightVtxIndices[] = {
        0, 2, 1,
    };

    Shape plane(planeVtxPositions, planeVtxIndices, nullptr, nullptr, 3, 1, 0, 0, -1, -1);
    Shape light(lightVtxPositions, lightVtxIndices, nullptr, nullptr, 3, 1, -1, 0, -1, -1);

    // Only the first column of dx is non-zero (random). NOTE(review): presumably
    // that gives vertex 0 a random velocity; the printout below confirms it.
    Eigen::Matrix<Float, 6, 3> dx;
    dx.setZero();
    //dx(1, 0) = 1.0;
    //dx(1, 1) = 1.0;
    //dx(1, 2) = 1.0;
    //dx.block(0, 0, 3, 1) = Vector(1, 0, 0);
    dx.block(0, 0, 3, 1) = Eigen::Vector3d::Random();
    //dx.block(0, 1, 3, 1) = dx.block(0, 0, 3, 1);
    //dx.block(0, 2, 3, 1) = dx.block(0, 0, 3, 1);
    plane.initVelocities(dx);

    VectorAD v0 = plane.getVertexAD(0);
    VectorAD v1 = plane.getVertexAD(1);
    VectorAD v2 = plane.getVertexAD(2);
    printf("velocity at v0 = (%.3f, %.3f, %.3f)\n", v0.grad(0)[0], v0.grad(0)[1], v0.grad(0)[2]);
    printf("velocity at v1 = (%.3f, %.3f, %.3f)\n", v1.grad(0)[0], v1.grad(0)[1], v1.grad(0)[2]);
    printf("velocity at v2 = (%.3f, %.3f, %.3f)\n", v2.grad(0)[0], v2.grad(0)[1], v2.grad(0)[2]);

    area.shape_id = 0;

    Scene scene(camera,
        std::vector<const Shape*>{&plane, &light},
        std::vector<const BSDF*>{&diffuseBSDF},
        std::vector<const Emitter*>{&area},
        std::vector<const PhaseFunction*>(),
        std::vector<const Medium*>());

    Vector x0(50, 0, -50);
    Vector n0(0, 1, 0);
    Float tau = 3.0;
    SphereSamplerAD sphere_sampler(x0, n0, tau, scene);

    int spp = 1000000;

    Float deltaTau = 0.001;
    Statistics res0;
    // Approximated reference: smooth the delta in tau with a narrow Gaussian.
    for (int test_i = 0; test_i < spp; test_i++) {
        //PositionSamplingRecordAD pRec;
        //plane.samplePositionAD(sampler.next2D(), pRec);
        //FloatAD path_length = (pRec.p - x0).norm();
        //Float pdf = 1.0 / plane.getArea(0);

        Vector2 uv = squareToStdNormal(sampler.next2D());
        Float tau_sampled[2];
        tau_sampled[0] = uv.x() * deltaTau + tau;
        Float pdf = normalPdf(tau_sampled[0], tau, deltaTau);
        // Mirrored (antithetic) partner; only used if the loop below runs k < 2.
        tau_sampled[1] = 2 * tau - tau_sampled[0];

        Vector2 rnd2 = sampler.next2D();
        Float tmp_res = 0.0;
        for (int k = 0; k < 1; k++) {
            Intersection its;
            Float jacobian;
            Float x_pdf;
            // Renamed from `sphere_sampler` so it no longer shadows the AD sampler above.
            SphereSampler tau_sampler(x0, n0, tau_sampled[k], scene);
            if (!tau_sampler.sample(rnd2, its, jacobian, x_pdf)) continue;

            VectorAD x1, n1;
            FloatAD J;
            scene.getPoint(its, x1, n1, J);
            FloatAD path_length = (x1 - x0).norm();

            FloatAD val = normalPdfAD(path_length, tau, deltaTau) * path_length.grad(0) * jacobian;
            tmp_res += val.val / (pdf * x_pdf);
        }
        res0.push(tmp_res);
    }

    // NOTE(review): unlike res0 (which pushes 0 for a rejected sample), the two
    // estimators below skip rejected samples entirely; confirm this is intended.
    Statistics res1, res2;
    // Approach 1
    for (int test_i = 0; test_i < spp; test_i++) {
        IntersectionAD its;
        FloatAD jacobian;
        Float pdf;
        if (!sphere_sampler.sample(sampler.next2D(), its, jacobian, pdf)) continue;

        VectorAD x1, n1;
        FloatAD J;
        scene.getPoint(its.toIntersection(), x1, n1, J);
        Float v_path_length = (x1 - x0).norm().grad(0);

        res1.push(v_path_length * jacobian.val / pdf);
    }

    // Approach 2
    // The plane holds a single triangle that nothing in the loop modifies, so its
    // AD vertices and area are looked up once instead of once per sample.
    const Vector3i &tri = plane.indices[0];
    const VectorAD &tri_v0 = plane.vertices[tri[0]];
    const VectorAD &tri_v1 = plane.vertices[tri[1]];
    const VectorAD &tri_v2 = plane.vertices[tri[2]];
    const FloatAD tri_area = plane.getAreaAD(0);

    for (int test_i = 0; test_i < spp; test_i++) {
        IntersectionAD its;
        FloatAD jacobian;
        Float pdf;
        if (!sphere_sampler.sample(sampler.next2D(), its, jacobian, pdf)) continue;

        VectorAD x1, n1;
        FloatAD J;
        scene.getPoint(its.toIntersection(), x1, n1, J);

        // Circle formed by the tau-sphere on the plane, and the in-plane outward
        // radial direction at x1.
        FloatAD proj_dist = -n1.dot(x1 - x0);
        Vector circle_center = x0 - n1.val * proj_dist.val;
        Float circle_radius = (x1.val - circle_center).norm();
        Vector p1_norm = (x1.val - circle_center).normalized();

        VectorAD y = its.p;

        // Barycentric coordinates of y in the moving triangle, re-applied to the
        // non-moving vertex values to get the curve motion in reference space.
        FloatAD u = 0.5 * (y - tri_v0).cross(tri_v2 - tri_v0).norm() / tri_area;
        FloatAD v = 0.5 * (y - tri_v0).cross(tri_v1 - tri_v0).norm() / tri_area;
        VectorAD p1 = (1.0f - u - v)*VectorAD(tri_v0.val) + u * VectorAD(tri_v1.val) + v * VectorAD(tri_v2.val);

        Float v_tau = p1.grad(0).dot(p1_norm);

        res2.push(v_tau * circle_radius / pdf);
    }

    printf("Gaussian approx: %.6f +- %.6f\n", res0.getMean(), res0.getCI());
    printf("Diff. delta: %.6f +- %.6f\n", res1.getMean(), res1.getCI());
    printf("Boundary: %.6f +- %.6f\n", res2.getMean(), res2.getCI());

    // res1 and res2 come from independent sample streams, so their means can only
    // agree up to sampling noise. Comparing them against a fixed Epsilon made the
    // test fail from noise alone. Instead, accept differences within the combined
    // confidence intervals (assumes getCI() is the half-width printed as "+-"),
    // with Epsilon as a floor.
    // NOTE(review): test1 asserts v_path_length * jacobian == -v_tau * circle_radius
    // per sample, which would imply res1 == -res2 here; confirm which sign is intended.
    Float tolerance = res1.getCI() + res2.getCI();
    if (tolerance < Epsilon) tolerance = Epsilon;

    if (std::abs(res1.getMean() - res2.getMean()) < tolerance)
        printf("Test passed!\n");
    else
        printf("Test failed...\n");
}

// Entry point for the time-of-flight derivative tests: runs the per-sample
// check (test1) and then the integrated check (test2), in that order.
void tof_test() {
    static void (*const kSubTests[])() = { test1, test2 };
    for (auto run_sub_test : kSubTests)
        run_sub_test();
}