#include "pathspace.h"
#include "scene.h"
#include "math_func.h"


// Rebuilds `sample_distrb` as a discrete distribution over the voxels of a
// grid with resolution `dims`, using the first dims[0]*dims[1]*dims[2]
// entries of `data` as unnormalized weights.
//
// The weights are rescaled by num_voxels/data.sum() before being appended, so
// (when `data` holds exactly one entry per voxel) they sum to the voxel count.
//
// The distribution is left cleared (no entries, normalize() not called) when
//   - the grid has at most one voxel, or
//   - the weights do not have a strictly positive, finite-comparable total
//     (all zeros, negative total, NaN). Previously this case divided by zero
//     and filled the distribution with NaN weights.
//
// Precondition: `data` holds at least one entry per voxel.
static void build_distrib(const Vector3i &dims, const Eigen::Array<Float, -1, 1> &data, DiscreteDistribution &sample_distrb)
{
    // Widen before multiplying so large grids do not overflow int.
    const long long num_voxels = static_cast<long long>(dims[0])*dims[1]*dims[2];

    sample_distrb.clear();
    if ( num_voxels <= 1 ) return;

    // data[i] is read for every i in [0, num_voxels).
    assert(static_cast<long long>(data.size()) >= num_voxels);

    // Written as !(x > 0) so that a NaN total is rejected as well.
    const Float total = data.sum();
    if ( !(total > static_cast<Float>(0)) ) return;

    const Float ratio = static_cast<Float>(num_voxels)/total;
    sample_distrb.reserve(num_voxels);
    for ( long long i = 0; i < num_voxels; ++i ) sample_distrb.append(data[i]*ratio);
    sample_distrb.normalize();
}


// Builds the tables used to sample edges of the scene.
//
// Two parallel tables are filled, slot 0 ("direct") and slot 1 ("indirect").
// Each slot has a list of (shape index, edge index) pairs in edge_indices and
// a matching discrete distribution in edge_distrbs. Slot 0 may contain fewer
// edges than slot 1 because edges that cannot face any emitter are dropped.
//
// Parameters:
//   scene          scene whose shapes, BSDFs and edges are inspected
//   shapeWeights   one weight per shape; shapes whose weight is not greater
//                  than Epsilon contribute no edges
//   direct_dims /  resolution and per-voxel weights of the grid distribution
//   direct_data    stored for slot 0
//   indirect_dims/ resolution and per-voxel weights of the grid distribution
//   indirect_data  stored for slot 1
//   verbose        print a summary to std::cout / std::cerr
PathSpaceEdgeManager::PathSpaceEdgeManager(const Scene& scene, const Eigen::Array<Float, -1, 1> &shapeWeights,
                                           const Vector3i &direct_dims, const Eigen::Array<Float, -1, 1> &direct_data,
                                           const Vector3i &indirect_dims, const Eigen::Array<Float, -1, 1> &indirect_data,
                                           bool verbose)
    : m_scene(scene)
{
    const std::vector<const Shape*>& shape_list = scene.shape_list;
    assert(shape_list.size() > 0);

    // Gather the vertices of every emitter shape and check whether all emitter
    // triangles share one normal and one plane.
    //   planarLight == -1 : no emitter triangle seen yet
    //   planarLight ==  1 : every emitter triangle seen so far is coplanar
    //   planarLight ==  0 : the emitters are not planar
    int planarLight = -1;
    std::vector<Vector> lightVertices;
    Vector lightNormal = Vector::Zero();
    for ( const Shape* shape : shape_list ) {
        if ( shape->light_id < 0 ) continue;

        // Vertices are collected for ALL emitter shapes, also after the
        // planarity test has failed: the back-edge culling below is only
        // valid when it sees every light vertex. (Stopping the scan at the
        // first non-planar emitter used to leave later emitters out.)
        for ( int i = 0; i < shape->num_vertices; ++i ) lightVertices.push_back(shape->getVertex(i));

        // Once non-planarity is established there is nothing left to test.
        if ( planarLight == 0 ) continue;

        for ( int i = 0; i < shape->num_triangles; ++i ) {
            const Vector &n = shape->getGeoNormal(i), p = shape->getVertex(shape->getIndices(i)[0]);
            if ( planarLight >= 0 ) {
                // Compare against the first emitter triangle: same normal and
                // first vertex lying in the same plane.
                if ( (lightNormal - n).norm() > ShadowEpsilon || std::abs(lightNormal.dot(p - lightVertices.front())) > Epsilon ) {
                    planarLight = 0;
                    break;
                }
            } else {
                planarLight = 1;
                lightNormal = n;
            }
        }
    }
    assert(planarLight >= 0);

    // The culling test is linear in the number of light vertices per edge, so
    // it is only enabled for small emitters.
    bool cull = (lightVertices.size() < 100);
    // bool cull = false;
    if ( verbose ) {
        if ( cull ) {
            std::cout << "[INFO] " << lightVertices.size() << " light vertices in total. Back-edge culling enabled." << std::endl;
        }
        else {
            std::cerr << "[WARN] Too many light vertices. Back-edge culling disabled." << std::endl;
        }
    }

    edge_distrbs[0].clear(); edge_indices[0].clear();
    edge_distrbs[1].clear(); edge_indices[1].clear();

    int nskipped = 0;
    for ( int i = 0; i < static_cast<int>(shape_list.size()); ++i ) {
        if ( !(shapeWeights[i] > Epsilon) ) {
            ++nskipped;
            continue;
        }

        const Shape &shape = *shape_list[i];
        const BSDF &bsdf = *scene.bsdf_list[shape.bsdf_id];
        // Shapes that are neither emitters nor carry a non-null BSDF have no
        // edges worth sampling.
        if ( shape.light_id < 0 && bsdf.isNull() ) continue;

        for ( int j = 0; j < static_cast<int>(shape.edges.size()); ++j ) {
            const Edge &edge = shape.edges[j];

            if ( edge.mode == 0 || bsdf.isTransmissive() ) {
                // Mode-0 edges and edges of transmissive shapes go into both
                // tables, weighted by length times the full sphere (4*pi).
                const Float weight = edge.length*4.0f*M_PI;
                edge_indices[0].push_back(Vector2i(i, j));
                edge_distrbs[0].append(weight);
                edge_indices[1].push_back(Vector2i(i, j));
                edge_distrbs[1].append(weight);
                continue;
            }

            const Vector &n0 = shape.getGeoNormal(edge.f0), &n1 = shape.getGeoNormal(edge.f1);
            Float tmp = n0.dot(n1);
            // The two adjacent faces must be neither parallel nor anti-parallel.
            assert(tmp > -1.0f + EdgeEpsilon && tmp < 1.0f - EdgeEpsilon);

            // Only edges with a positive mode are sampled from here on.
            if ( edge.mode <= 0 ) continue;

            // Weight: edge length times four times the angle between the two
            // face normals.
            const Float weight = edge.length*4.0f*std::acos(tmp);

            const Vector &v0 = shape.getVertex(edge.v0), &v1 = shape.getVertex(edge.v1);
            Vector c = 0.5f*(v0 + v1);

            // Direct table: with a planar emitter, keep the edge only if its
            // midpoint lies on the positive side of the emitter plane.
            // `planarLight != 1` (rather than `!planarLight`) also covers the
            // "no emitter" state when asserts are compiled out, so that
            // lightVertices.front() is never evaluated on an empty vector.
            if ( planarLight != 1 || lightNormal.dot(c - lightVertices.front()) > Epsilon ) {
                // Back-edge culling: the edge is dead when no light vertex
                // lies on the positive side of either adjacent face.
                bool dead = false;
                if ( cull ) {
                    dead = true;
                    for ( const Vector &p : lightVertices ) {
                        Vector e = p - v0;
                        if ( n0.dot(e) > Epsilon || n1.dot(e) > Epsilon ) {
                            dead = false; break;
                        }
                    }
                }

                if ( !dead ) {
                    edge_indices[0].push_back(Vector2i(i, j));
                    edge_distrbs[0].append(weight);
                }
            }

            // Indirect table: every positive-mode edge is kept.
            edge_indices[1].push_back(Vector2i(i, j));
            edge_distrbs[1].append(weight);
        }
    }

    edge_distrbs[0].normalize();
    edge_distrbs[1].normalize();
    assert(edge_distrbs[0].size() == edge_indices[0].size() && edge_distrbs[1].size() == edge_indices[1].size());
    if ( verbose ) {
        std::cout << "[INFO] Initialized " << edge_indices[0].size() << " direct edges and " << edge_indices[1].size() << " indirect edges";
        if ( nskipped > 0 ) std::cout << " (with " << nskipped << " shapes skipped)";
        std::cout << std::endl;
    }

    // Direct
    grid_dims[0] = direct_dims;
    build_distrib(direct_dims, direct_data, grid_distrbs[0]);

    // Indirect
    grid_dims[1] = indirect_dims;
    build_distrib(indirect_dims, indirect_data, grid_distrbs[1]);
}


// Samples a point on a scene edge together with a ray direction leaving it.
//
// Parameters:
//   _rnd      three random numbers; copied locally, the caller's value is
//             not modified
//   shape_id  out: index of the shape owning the sampled edge, or -1 when the
//             sample has to be discarded (emitter point not visible in the
//             direct case, or the direction does not lie on opposite sides
//             of the two adjacent faces)
//   ray       out: ray.org is the point on the edge; ray.dir is the sampled
//             direction (left untouched when the emitter point is not visible)
//   pdf       out: edge-selection probability divided by the edge length,
//             times the directional density, times the grid factor
//   direct    true: use the direct tables and aim at a sampled emitter point;
//             false: use the indirect tables and sample a direction locally
//
// Returns a reference to the sampled edge, also when shape_id is set to -1.
const Edge& PathSpaceEdgeManager::sampleEdgeRay(const Vector &_rnd, int &shape_id, RayAD &ray, Float &pdf, bool direct) const {
    // Slot 0 holds the direct tables, slot 1 the indirect ones.
    const int slot = direct ? 0 : 1;
    Vector sample(_rnd);

    // Optional grid step: pick a voxel of the unit cube and place the sample
    // inside it. Skipped when the grid distribution carries no mass.
    Float gridPdf = 1.0f;
    {
        const DiscreteDistribution &cells = grid_distrbs[slot];
        const Vector3i &res = grid_dims[slot];
        if ( cells.getSum() > Epsilon ) {
            long long cell = cells.sampleReuse(sample[2], gridPdf);

            // Decompose the flat voxel index into (ci, cj, ck); the last grid
            // dimension varies fastest.
            const int slab = res[1]*res[2];
            const int ci = static_cast<int>(cell/slab);
            cell %= slab;
            const int cj = static_cast<int>(cell/res[2]);
            const int ck = static_cast<int>(cell % res[2]);

            gridPdf *= cells.getSum();
            sample = (Vector(ci, cj, ck) + sample).cwiseQuotient(Vector(res[0], res[1], res[2]));
            assert(sample[0] > -Epsilon && sample[0] < 1.0f + Epsilon &&
                   sample[1] > -Epsilon && sample[1] < 1.0f + Epsilon &&
                   sample[2] > -Epsilon && sample[2] < 1.0f + Epsilon);
        }
    }

    // Pick the edge. sample[0] is handed to sampleReuse() and afterwards
    // reused as the position along the edge (presumably sampleReuse rescales
    // it for that purpose -- confirm against DiscreteDistribution).
    const std::vector<Vector2i> &candidates = edge_indices[slot];
    const DiscreteDistribution &edgePicker = edge_distrbs[slot];
    assert(edgePicker.getSum() > Epsilon);

    const Vector2i picked = candidates[edgePicker.sampleReuse(sample[0], pdf)];
    shape_id = picked[0];
    const Shape &shape = *m_scene.shape_list[shape_id];
    const Edge &edge = shape.edges[picked[1]];
    assert(edge.f0 >= 0);
    pdf /= edge.length;

    // Point on the edge, interpolated between its two end points.
    const VectorAD &v0 = shape.getVertexAD(edge.v0), &v1 = shape.getVertexAD(edge.v1);
    ray.org = v0 + (v1 - v0)*sample[0];

    // Normals of the adjacent faces; n1 stays null when there is no second face.
    const Vector *n0 = &shape.getGeoNormal(edge.f0);
    const Vector *n1 = nullptr;
    if ( edge.f1 >= 0 ) n1 = &shape.getGeoNormal(edge.f1);

    if ( direct ) {
        // Aim at a sampled emitter point.
        Intersection its;
        m_scene.sampleEmitterPosition(Vector2(sample[1], sample[2]), its);
        if ( !m_scene.isVisible(ray.org.val, true, its.p, true) ) {
            // Discard the sample; clearing n1 also skips the side test below.
            shape_id = -1;
            n1 = nullptr;
        } else {
            Vector toLight = its.p - ray.org.val;
            const Float dist2 = toLight.squaredNorm();
            toLight /= std::sqrt(dist2);
            // Emitter pdf times squared distance over the cosine at the emitter.
            const Float emitterPdf = m_scene.pdfEmitterSample(its)*dist2/std::abs(its.geoFrame.n.dot(-toLight));

            ray.dir = VectorAD(toLight);
            pdf *= emitterPdf;
        }
    } else if ( n1 == nullptr ) {
        // Case 1: boundary edge -- uniform direction on the sphere, expressed
        // in a frame whose axis is perpendicular to both n0 and the edge.
        Vector axis = n0->cross(v1.val - v0.val);
        const Float axisLen = axis.norm();
        assert(axisLen > ShadowEpsilon);
        Frame frame(axis/axisLen);

        // squareToUniformSphere
        Float z = 1.0f - 2.0f*sample[2];

        // Avoid sampling directions that are almost within the plane
        z = std::min(z, static_cast<Float>(1.0f - EdgeEpsilon));

        const Float r = std::sqrt(std::max(static_cast<Float>(0.0f), 1.0f - z*z));
        Float sinPhi, cosPhi;
        math::sincos(2.0f*M_PI*sample[1], sinPhi, cosPhi);

        ray.dir = VectorAD(frame.toWorld(Vector(r*cosPhi, r*sinPhi, z)));
        pdf /= (4.0f*M_PI);
    } else {
        // Case 2: non-boundary edge -- direction sampled from the two normals.
        Float dirPdf;
        ray.dir = VectorAD(squareToEdgeRayDirection(Vector2(sample[1], sample[2]), *n0, *n1, dirPdf));
        pdf *= dirPdf;
    }

    // With two adjacent faces, the direction must lie on opposite sides of
    // them (product of the signs equal to -1); otherwise the sample is rejected.
    if ( n1 != nullptr ) {
        const Float dotn0 = ray.dir.val.dot(*n0);
        const Float dotn1 = ray.dir.val.dot(*n1);
        if ( math::signum(dotn0)*math::signum(dotn1) > -0.5f ) {
            if ( !direct ) std::cerr << "\n[WARN] Bad edge ray sample: [" << dotn0 << ", " << dotn1 << "]" << std::endl;
            shape_id = -1;
        }
    }

    pdf *= gridPdf;
    return edge;
}
