#include "scene.h"
#include "shape.h"
#include "frame.h"
#include "frameAD.h"
#include "intersection.h"
#include "intersectionAD.h"
#include "ray.h"
#include "rayAD.h"
#include "../emitter/area.h"
#include "../bsdf/diffuse.h"
#include "camera.h"
#include "sampler.h"
#include "stats.h"
#include <iostream>
#include <chrono>
#include <assert.h>

// Number of OpenMP worker threads used by the sampling loops in this file.
// It is assigned from omp_get_num_procs() in compute_grid() and path_test();
// the static helpers below size their per-thread sampler and statistics
// arrays from it, so it must be set before they run (it is zero until then).
static int nworker;


static void simpleCaseFD(const Scene &scene0, const Scene &scene1, int nBatches, Float delta)
{
    constexpr long long batchSize = 100000000LL;

    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; i++) {
        samplers.push_back(RndSampler(0, i));
    }

    std::cout << "\n===== FD =====" << std::endl;

    Statistics resStat;
    for ( int batch = 1; batch <= nBatches; ++batch ) {
        std::vector<Statistics> stats(nworker);
        std::cout << "  [" << batch << '/' << nBatches << "]: " << std::flush;
        std::chrono::time_point start = std::chrono::high_resolution_clock::now();

#pragma omp parallel for num_threads(nworker)
        for (long long omp_i = 0; omp_i < batchSize; ++omp_i) {
            const int tid = omp_get_thread_num();
            const Array2 rnd0 = samplers[tid].next2D(), //shared random number for shape0 and shape1 of x1
                         rnd1 = samplers[tid].next2D(); //shared random number for shape0 and shape1 of x2

            PositionSamplingRecord pRec0, pRec1;
            Vector2i ind0, ind1;
            Ray ray;
            Intersection its;
            Float dist;

            Float before = 0.0f;
            ind0 = scene0.samplePosition(rnd0, pRec0);
            ind1 = scene0.samplePosition(rnd1, pRec1);
            if ( ind0 != ind1 ) {
                ray.org = pRec0.p;
                ray.dir = pRec1.p - pRec0.p;
                dist = ray.dir.norm();
                ray.dir /= dist;

                if ( scene0.rayIntersect(ray, true, its) ) { // mutual visibility function
                    if ( std::abs(its.t - dist) < ShadowEpsilon ) {
                        Float cos1 = std::abs(pRec0.n.dot(ray.dir)), cos2 = std::abs(pRec1.n.dot(-ray.dir));
                        before = cos1*cos2/(dist*dist);
                        before *= std::pow(scene0.getArea(), 2.0f); // prob density
                    }
                }
            }

            Float after = 0.0f;
            ind0 = scene1.samplePosition(rnd0, pRec0);
            ind1 = scene1.samplePosition(rnd1, pRec1);
            if ( ind0 != ind1 ) {
                ray.org = pRec0.p;
                ray.dir = pRec1.p - pRec0.p;
                dist = ray.dir.norm();
                ray.dir /= dist;

                if ( scene1.rayIntersect(ray, true, its) ) { // mutual visibility function
                    if ( std::abs(its.t - dist) < ShadowEpsilon ) {
                        Float cos1 = std::abs(pRec0.n.dot(ray.dir)), cos2 = std::abs(pRec1.n.dot(-ray.dir));
                        after = cos1*cos2/(dist*dist);
                        after *= std::pow(scene1.getArea(), 2.0f); // prob density
                    }
                }
            }

            // stats[tid].push(before);
            stats[tid].push((after - before)/delta);
        }

        std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double> >(
            std::chrono::high_resolution_clock::now() - start
        );

        for (int i = 0; i < nworker; ++i) resStat.push(stats[i]);
        std::cout << resStat.getMean() << " +- " << resStat.getCI() << " (" << time_span.count() << " secs)" << std::endl;
    }
}


// Evaluates one two-strategy (area sampling + direction sampling) sample.
//
// `rnd0` selects the first surface point. `rnd1` is used twice: once to pick
// a second surface point by area, and once to pick a direction from the first
// point via squareToCosineSphere(). Each strategy's contribution is weighted
// by its own density over the sum of both densities, and the two weighted
// contributions are added.
static Float evalMIS(const Scene &scene, const Array2 &rnd0, const Array2 &rnd1)
{
    const auto area = scene.getArea();
    Float total = 0.0f;

    PositionSamplingRecord x1;
    const Vector2i id1 = scene.samplePosition(rnd0, x1);

    Ray ray;
    Intersection hit;

    // Strategy 1: pick the second point by area.
    PositionSamplingRecord x2;
    if ( scene.samplePosition(rnd1, x2) != id1 ) {
        ray.org = x1.p;
        ray.dir = x2.p - x1.p;
        const Float dist = ray.dir.norm();
        ray.dir /= dist;

        // Only count the pair when the first hit along the ray is the sampled point.
        if ( scene.rayIntersect(ray, true, hit) && std::abs(hit.t - dist) < ShadowEpsilon ) {
            const Float cosProd = std::abs(x1.n.dot(ray.dir)*x2.n.dot(-ray.dir));
            const Float contrib = area*cosProd/(dist*dist);
            const Float pdfArea = 1.0f/area;
            const Float pdfDir = cosProd/(2.0f*M_PI*dist*dist);
            assert(std::isfinite(pdfDir));

            total += area*contrib*pdfArea/(pdfArea + pdfDir);
        }
    }

    // Strategy 2: pick a direction and trace to find the second point.
    ray.org = x1.p;
    const Vector localDir = squareToCosineSphere(rnd1);
    ray.dir = Frame(x1.n).toWorld(localDir);
    if ( scene.rayIntersect(ray, true, hit) && hit.t*hit.t > Epsilon ) {
        const Float cosProd = std::abs(localDir.z()*hit.geoFrame.n.dot(-ray.dir));
        const Float contrib = 2.0f*M_PI;
        const Float pdfArea = 1.0f/area;
        const Float pdfDir = cosProd/(2.0*M_PI*hit.t*hit.t);
        assert(std::isfinite(pdfDir));

        total += area*contrib*pdfDir/(pdfArea + pdfDir);
    }

    return total;
}


static void simpleCaseFD2(const Scene &scene0, const Scene &scene1, int nBatches, Float delta)
{
    constexpr long long batchSize = 100000000LL;

    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; i++) {
        samplers.push_back(RndSampler(0, i));
    }

    std::cout << "\n===== FD2 =====" << std::endl;

    Statistics resStat;
    for ( int batch = 1; batch <= nBatches; ++batch ) {
        std::vector<Statistics> stats(nworker);
        std::cout << "  [" << batch << '/' << nBatches << "]: " << std::flush;
        std::chrono::time_point start = std::chrono::high_resolution_clock::now();

#pragma omp parallel for num_threads(nworker)
        for (long long omp_i = 0; omp_i < batchSize; ++omp_i) {
            const int tid = omp_get_thread_num();
            const Array2 rnd0 = samplers[tid].next2D(), //shared random number for shape0 and shape1 of x1
                         rnd1 = samplers[tid].next2D(); //shared random number for shape0 and shape1 of omega

            const Float before = evalMIS(scene0, rnd0, rnd1),
                        after = evalMIS(scene1, rnd0, rnd1);

            // stats[tid].push(before);
            stats[tid].push((after - before)/delta);
        }

        std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double> >(
            std::chrono::high_resolution_clock::now() - start
        );

        for (int i = 0; i < nworker; ++i) resStat.push(stats[i]);
        std::cout << resStat.getMean() << " +- " << resStat.getCI() << " (" << time_span.count() << " secs)" << std::endl;
    }
}


// Autodiff counterpart of evalMIS(): the same two sampling strategies, but
// the contribution is carried as a FloatAD while the densities used for the
// weights stay plain Floats taken from the `.val` parts.
//
// `rnd0` selects the first surface point; `rnd1` drives both the area-sampled
// second point and the direction-sampled second point.
static FloatAD evalMIS_AD(const Scene &scene, const Array2 &rnd0, const Array2 &rnd1)
{
    const auto area = scene.getArea();
    FloatAD total(0.0f);

    PositionSamplingRecordAD x1;
    const Vector2i id1 = scene.samplePositionAD(rnd0, x1);

    Intersection hit;

    // Strategy 1: pick the second point by area.
    PositionSamplingRecordAD x2;
    if ( scene.samplePositionAD(rnd1, x2) != id1 ) {
        VectorAD w = x2.p - x1.p;
        FloatAD len = w.norm();
        w /= len;

        // Visibility is tested on the primal values only.
        if ( scene.rayIntersect(Ray(x1.p.val, w.val), true, hit) && std::abs(hit.t - len.val) < ShadowEpsilon ) {
            FloatAD cosProd = (x1.n.dot(w)*x2.n.dot(-w)).abs();
            FloatAD contrib = area*cosProd*x1.J*x2.J/len.square();
            const Float pdfArea = 1.0f/area;
            const Float pdfDir = cosProd.val/(2.0f*M_PI*len.val*len.val);
            assert(std::isfinite(pdfDir));

            total += area*(contrib*pdfArea/(pdfArea + pdfDir));
        }
    }

    // Strategy 2: pick a direction, trace it, then rebuild the hit point from
    // the AD vertices of the triangle that was hit.
    Ray ray(x1.p.val, Frame(x1.n.val).toWorld(squareToCosineSphere(rnd1)));
    if ( scene.rayIntersect(ray, true, hit) && hit.t*hit.t > Epsilon ) {
        const Shape &hitShape = *scene.shape_list[hit.indices[0]];
        const Vector3i tri = hitShape.getIndices(hit.indices[1]);

        VectorAD hitPoint = (1.0f - hit.barycentric[0] - hit.barycentric[1])*hitShape.getVertexAD(tri[0]) +
                            hit.barycentric[0]*hitShape.getVertexAD(tri[1]) +
                            hit.barycentric[1]*hitShape.getVertexAD(tri[2]);
        VectorAD hitNormal = hitShape.getGeoNormalAD(hit.indices[1]);

        // Triangle area divided by its own primal value: value 1, derivative kept.
        FloatAD hitJ = hitShape.getAreaAD(hit.indices[1]);
        hitJ /= hitJ.val;

        VectorAD w = hitPoint - x1.p;
        FloatAD len = w.norm();
        w /= len;

        FloatAD cosProd = (x1.n.dot(w)*hitNormal.dot(-w)).abs();
        if ( cosProd.val > Epsilon ) {
            FloatAD contrib = cosProd*x1.J*hitJ/len.square();
            contrib *= 2.0f*M_PI/contrib.val; // rescale so the primal value is 2*pi
            const Float pdfArea = 1.0f/area;
            const Float pdfDir = cosProd.val/(2.0f*M_PI*hit.t*hit.t);
            assert(std::isfinite(pdfDir));

            total += area*(contrib*pdfDir/(pdfArea + pdfDir));
        }
    }

    return total;
}


// Evaluates the edge (boundary) term for one sampled edge ray.
//
// scene : scene to sample and trace in.
// rnd   : 3D random number forwarded to scene.sampleEdgeRay().
// pdf   : output; passed by reference to scene.sampleEdgeRay(), which is
//         presumably where it is filled in — confirm against Scene.
//
// Returns 0 when the edge ray misses the scene in either direction or when
// either hit lies on one of the two faces (f0/f1) recorded on the sampled
// edge. Otherwise returns the value computed at the end of the inner branch.
// The result is NOT divided by `pdf`; the callers in this file do that.
static Float evalEdge(const Scene &scene, const Vector &rnd, Float &pdf) {
    Float ret = 0.0f;
    int shape_id;
    RayAD edgeRay;
    const Edge &rEdge = scene.sampleEdgeRay(rnd, shape_id, edgeRay, pdf);
    assert(rEdge.f0 >= 0);
    const Shape &shape = *scene.shape_list[shape_id];

    // Trace the edge ray forwards (its2) and backwards (its1).
    Intersection its1, its2;
    Ray _edgeRay = edgeRay.toRay();
    if ( scene.rayIntersect(_edgeRay, true, its2) && scene.rayIntersect(_edgeRay.flipped(), true, its1) ) {
        // Index pairs (shape, face) of the two faces stored on the sampled edge.
        const Vector2i ind0(shape_id, rEdge.f0), ind1(shape_id, rEdge.f1);
        // Skip samples whose end points fall on either of those faces.
        if ( its1.indices != ind0 && its1.indices != ind1 && its2.indices != ind0 && its2.indices != ind1 ) {
            const Shape &shape1 = *scene.shape_list[its1.indices[0]], &shape2 = *scene.shape_list[its2.indices[0]];
            const Vector3i &f1 = shape1.getIndices(its1.indices[1]), &f2 = shape2.getIndices(its2.indices[1]);

            // u2: AD result of intersecting the ray (its1 point -> edge point)
            // with the triangle hit by its2. Presumably the AD hit position;
            // the return value of rayIntersectTriangleAD is ignored here.
            VectorAD u2;
            {
                RayAD ray;
                // Rebuild the its1 hit point from the AD vertices using its barycentric coordinates.
                ray.org = (1.0f - its1.barycentric[0] - its1.barycentric[1])*shape1.getVertexAD(f1[0]) +
                          its1.barycentric[0]*shape1.getVertexAD(f1[1]) +
                          its1.barycentric[1]*shape1.getVertexAD(f1[2]);
                assert((ray.org.val - its1.p).norm() < ShadowEpsilon);
                ray.dir = edgeRay.org - ray.org;

                const VectorAD &v0 = shape2.getVertexAD(f2[0]), &v1 = shape2.getVertexAD(f2[1]), &v2 = shape2.getVertexAD(f2[2]);
                rayIntersectTriangleAD(v0, v1, v2, ray, u2);
            }

            // Unit direction and distance between the two hit points (primal values).
            Vector d = its2.p - its1.p;
            Float dist = d.norm();
            d /= dist;

            Float cos2 = std::abs(its2.geoFrame.n.dot(d));
            // e: unit vector along the sampled edge.
            Vector e = (shape.getVertex(rEdge.v0) - shape.getVertex(rEdge.v1)).normalized();
            Float sinphi = e.cross(d).norm();
            // proj: unit vector perpendicular to both (e x d) and the its2 geometric normal.
            Vector proj = e.cross(d).cross(its2.geoFrame.n).normalized();
            Float sinphi2 = d.cross(proj).norm();
            // n: unit vector perpendicular to both the its2 geometric normal and proj.
            Vector n = its2.geoFrame.n.cross(proj).normalized();
            assert(std::abs(its2.geoFrame.n.norm() - 1.0f) < ShadowEpsilon);
            assert(std::abs(e.norm() - 1.0f) < ShadowEpsilon);
            assert(std::abs(d.norm() - 1.0f) < ShadowEpsilon);
            assert(std::abs(proj.norm() - 1.0f) < ShadowEpsilon);

            // Sign of the term: +1 if a ray from its1.p towards its2.p nudged
            // by n*ShadowEpsilon is blocked before parameter 1, otherwise -1.
            // NOTE(review): the comparison against 1 assumes Ray keeps the
            // unnormalized direction, so t is in units of the segment — confirm.
            // NOTE(review): the result of rayIntersect is ignored, so on a miss
            // vi.t holds whatever Intersection's default state is — confirm.
            Intersection vi;
            Float deltaVis = -1.0f;
            scene.rayIntersect(Ray(its1.p, (its2.p + n*ShadowEpsilon) - its1.p), true, vi);
            if (vi.t < 1.0f - ShadowEpsilon) deltaVis = 1.0f;
            // Derivative of u2 along n, signed, scaled by the distance ratio
            // (its1 -> edge point)/(its1 -> its2), sinphi/sinphi2 and cos2.
            ret = n.dot(u2.grad()) * deltaVis * (edgeRay.org.val - its1.p).norm()/dist * sinphi/sinphi2 * cos2;
        }
        // else
        //     std::cerr << "Bad sample" << std::endl;
    }
    return ret;
}


static void simpleCaseEQ(const Scene &scene0, int nBatches) {
    constexpr long long batchSize = 100000000LL;

    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; i++) {
        samplers.push_back(RndSampler(0, i));
    }

    std::cout << "\n===== AD =====" << std::endl;

    Statistics resStat;
    for ( int batch = 1; batch <= nBatches; ++batch ) {
        std::vector<Statistics> stats(nworker);
        std::cout << "  [" << batch << '/' << nBatches << "]: " << std::flush;
        std::chrono::time_point start = std::chrono::high_resolution_clock::now();

#pragma omp parallel for num_threads(nworker)
        for (long long omp_i = 0; omp_i < batchSize; ++omp_i) {
            const int tid = omp_get_thread_num();
            Float value;

            // autodiff term
            {
                const Array2 rnd0 = samplers[tid].next2D(), rnd1 = samplers[tid].next2D();
#if 0
                PositionSamplingRecordAD pRec0, pRec1;
                Vector2i ind0 = scene0.samplePositionAD(rnd0, pRec0),
                         ind1 = scene0.samplePositionAD(rnd1, pRec1);
                if ( ind0[0] != ind1[0] || ind0[1] != ind1[1] ) {
                    Ray Ray0(pRec0.p.val, pRec1.p.val - pRec0.p.val);
                    Intersection its0;
                    if ( scene0.rayIntersect(Ray0, true, its0) ) {
                        if ( std::abs(its0.t - 1.0f) <= ShadowEpsilon) {
                            VectorAD v = pRec1.p-pRec0.p;
                            FloatAD len = v.norm();
                            v /= len;
                            FloatAD cos1 = pRec0.n.dot(v).abs(), cos2 = pRec1.n.dot(v).abs();
                            value = (cos1*cos2/(len*len) * pRec0.J * pRec1.J).grad();
                            value *= pow(scene0.getArea(), 2);
                        }
                    }
                }
#else
                value = evalMIS_AD(scene0, rnd0, rnd1).grad();
#endif
            }

            // edge term
            {
                Float edgeValue, pdf;
                if ( std::abs(edgeValue = evalEdge(scene0, samplers[tid].next3D(), pdf)) > Epsilon )
                    value += edgeValue/pdf;
            }
            stats[tid].push(value);
        }

        std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double> >(
            std::chrono::high_resolution_clock::now() - start
        );

        for (int i = 0; i < nworker; ++i) resStat.push(stats[i]);
        std::cout << resStat.getMean() << " +- " << resStat.getCI() << " (" << time_span.count() << " secs)" << std::endl;
    }
}

// Fills `grid_data` with a per-cell estimate of the edge-term magnitude and
// rescales it so the cells average to one.
//
// shapes    : shapes the scene is built from.
// para      : {nx, ny, nz, samplesPerCell}; the unit cube of edge-sampling
//             random numbers is split into nx*ny*nz cells.
// grid_data : output, indexed as (i*ny + j)*nz + k; must hold nx*ny*nz floats.
//
// For every cell, `samplesPerCell` random points inside the cell are fed to
// evalEdge() and the mean of |value/pdf| is stored. If the accumulated sum is
// not positive (every cell is zero, or a NaN was produced) the grid is left
// unscaled instead of being multiplied by an infinite or NaN factor.
//
// Side effect: sets the file-level `nworker` to omp_get_num_procs().
void compute_grid(const std::vector<const Shape *> &shapes, const std::vector<int> &para, ptr<float> grid_data)
{
    assert(para.size() >= 4);
    assert(para[0] > 0 && para[1] > 0 && para[2] > 0 && para[3] > 0);

    DiffuseBSDF diffuse(Spectrum3f(0.4f, 0.5f, 0.6f));
    Camera camera0;
    Scene scene0(camera0, shapes,
                 std::vector<const BSDF*>{&diffuse},
                 std::vector<const Emitter*>(),
                 std::vector<const PhaseFunction*>(),
                 std::vector<const Medium*>());

    nworker = omp_get_num_procs();
    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; i++) samplers.push_back(RndSampler(0, i));

    // Cell counts are kept in 64 bits so ny*nz cannot overflow an int.
    const long long sliceSize = static_cast<long long>(para[1])*para[2];
    const long long nCells = static_cast<long long>(para[0])*sliceSize;
    const int samplesPerCell = para[3];

    Eigen::Array<Float, -1, 1> sums = Eigen::Array<Float, -1, 1>::Zero(nworker);
    // The team size is pinned to nworker because `samplers` and `sums` have
    // exactly nworker entries and are indexed by the thread number.
#pragma omp parallel for num_threads(nworker)
    for ( long long omp_i = 0; omp_i < nCells; ++omp_i ) {
        // Unflatten the cell index into (i, j, k).
        const int i = static_cast<int>(omp_i/sliceSize);
        const long long rem = omp_i % sliceSize;
        const int j = static_cast<int>(rem/para[2]);
        const int k = static_cast<int>(rem % para[2]);

        const int tid = omp_get_thread_num();
        Float ret = 0.0f, pdf, val;
        for ( int t = 0; t < samplesPerCell; ++t ) {
            // Uniform point inside cell (i, j, k), mapped back to the unit cube.
            Vector rnd = (Vector(i, j, k) + samplers[tid].next3D().matrix()).cwiseQuotient(Vector(para[0], para[1], para[2]));
            val = evalEdge(scene0, rnd, pdf);
            ret += std::abs(val/pdf);
        }
        sums[tid] += (grid_data[omp_i] = static_cast<float>(ret/samplesPerCell));
    }

    // Rescale so that the grid values average to one.
    const Float total = static_cast<Float>(sums.sum());
    if ( total > 0.0f ) {
        const Float w = static_cast<Float>(para[0])*para[1]*para[2]/total;
        for ( long long cell = 0; cell < nCells; ++cell )
            grid_data[cell] *= w;
    }
}

static void gridCaseEQ(const Scene &scene0, int nBatches, const std::vector<int> &dim, ptr<float> grid_data) {
    constexpr long long batchSize = 100000000LL;
    std::vector<RndSampler> samplers;
    for (int i = 0; i < nworker; i++) {
        samplers.push_back(RndSampler(0, i));
    }

    DiscreteDistribution cubeDistr;
    for ( long long i = 0; i < static_cast<long long>(dim[0])*dim[1]*dim[2]; ++i )
        cubeDistr.append(grid_data[i]);
    cubeDistr.normalize();

    std::cout << "\n===== Grid AD =====" << std::endl;
    Statistics resStat;
    for ( int batch = 1; batch <= nBatches; ++batch ) {
        std::vector<Statistics> stats(nworker);
        std::cout << "  [" << batch << '/' << nBatches << "]: " << std::flush;
        std::chrono::time_point start = std::chrono::high_resolution_clock::now();

#pragma omp parallel for num_threads(nworker)
        for (long long omp_i = 0; omp_i < batchSize; ++omp_i) {
            const int tid = omp_get_thread_num();
            Float value;

            // autodiff term
            {
                const Array2 rnd0 = samplers[tid].next2D(), rnd1 = samplers[tid].next2D();
                value = evalMIS_AD(scene0, rnd0, rnd1).grad();
            }

            // edge term with grid
            {
                Vector rnd = samplers[tid].next3D();
                const long long ind = static_cast<long long>(cubeDistr.sampleReuse(rnd[0]));
                long long tmp = ind;
                rnd[0] = (rnd[0] + static_cast<Float>(tmp/(dim[1]*dim[2])))/static_cast<Float>(dim[0]);
                tmp %= (dim[1]*dim[2]);
                rnd[1] = (rnd[1] + static_cast<Float>(tmp/dim[2]))/static_cast<Float>(dim[1]);
                rnd[2] = (rnd[2] + static_cast<Float>(tmp % dim[2]))/static_cast<Float>(dim[2]);
                assert( rnd[0] > -Epsilon && rnd[0] < 1.0f + Epsilon &&
                        rnd[1] > -Epsilon && rnd[1] < 1.0f + Epsilon &&
                        rnd[2] > -Epsilon && rnd[2] < 1.0f + Epsilon );

                Float edgeValue, pdf;
                if ( std::abs(edgeValue = evalEdge(scene0, rnd, pdf)) > Epsilon )
                    value += edgeValue/(pdf*grid_data[ind]);
            }
            stats[tid].push(value);
        }

        std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double> >(
            std::chrono::high_resolution_clock::now() - start
        );

        for (int i = 0; i < nworker; ++i) resStat.push(stats[i]);
        std::cout << resStat.getMean() << " +- " << resStat.getCI() << " (" << time_span.count() << " secs)" << std::endl;
    }
}

// Runs the derivative estimators in this file on a scene built from `shapes`.
//
// shapes    : shapes of the reference scene (scene0).
// nBatches  : batch counts for, in order, simpleCaseFD, simpleCaseFD2,
//             simpleCaseEQ and gridCaseEQ. An estimator runs only when its
//             entry is positive; entries missing from a short vector are
//             treated as zero instead of being read out of bounds.
// _delta    : finite-difference step; each shape of the second scene is a
//             copy of the original passed through Shape::advance(delta).
// dim       : grid resolution for gridCaseEQ; needs at least three entries
//             because gridCaseEQ reads dim[0..2]. With fewer entries the grid
//             estimator is skipped.
// grid_data : grid values forwarded to gridCaseEQ.
//
// Side effects: sets the file-level `nworker` to omp_get_num_procs() and
// switches std::cout and std::cerr to scientific notation with precision 2.
void path_test(const std::vector<const Shape *> &shapes, const std::vector<int> &nBatches, float _delta, const std::vector<int> &dim, ptr<float> grid_data)
{
    const Float delta = static_cast<Float>(_delta);
    const int nshapes = static_cast<int>(shapes.size());

    // Perturbed copies of the shapes. The vector is sized up front so the
    // addresses stored in shapes1 stay valid.
    std::vector<Shape> shapes1_data(nshapes);
    std::vector<const Shape *> shapes1(nshapes);
    for ( int i = 0; i < nshapes; ++i ) {
        shapes1_data[i] = *shapes[i];
        shapes1_data[i].advance(delta);
        shapes1[i] = &shapes1_data[i];
    }

    DiffuseBSDF diffuse(Spectrum3f(0.4f, 0.5f, 0.6f));
    Camera camera0;

    Scene scene0(camera0, shapes,
                 std::vector<const BSDF*>{&diffuse},
                 std::vector<const Emitter*>(),
                 std::vector<const PhaseFunction*>(),
                 std::vector<const Medium*>());

    Scene scene1(camera0, shapes1,
                 std::vector<const BSDF*>{&diffuse},
                 std::vector<const Emitter*>(),
                 std::vector<const PhaseFunction*>(),
                 std::vector<const Medium*>());

    nworker = omp_get_num_procs();
    std::cout << "Working with " << nworker << " workers." << std::endl;

    std::cout.precision(2);
    std::cout.setf(std::ios::scientific);

    std::cerr.precision(2);
    std::cerr.setf(std::ios::scientific);

    // Bounds-checked access: an absent entry means "do not run this estimator".
    const auto batchCount = [&nBatches](std::size_t idx) -> int {
        return idx < nBatches.size() ? nBatches[idx] : 0;
    };

    if ( batchCount(0) > 0 ) simpleCaseFD(scene0, scene1, batchCount(0), delta);
    if ( batchCount(1) > 0 ) simpleCaseFD2(scene0, scene1, batchCount(1), delta);
    if ( batchCount(2) > 0 ) simpleCaseEQ(scene0, batchCount(2));
    if ( batchCount(3) > 0 && dim.size() >= 3 ) gridCaseEQ(scene0, batchCount(3), dim, grid_data);
}
