#include <render/scene.h>
#include "algorithm1.h"
#include "algorithm1_vol.h"

NAMESPACE_BEGIN(algorithm1_vol)

/**
 * Recomputes the position-dependent fields of vertex `v` in place from its
 * stored parameterization.
 *
 * - type == EVSensor: v.p is set to the camera position.
 * - EVSurface bit set (checked before EVVolume, so it wins if both are set):
 *   v.p is interpolated from the triangle's three vertices using
 *   v.barycentric, the geometric and shading frames are rebuilt, and v.J is
 *   set to area / detach(area).
 * - EVVolume bit set: v.J defaults to 1. When the scene uses a tet mesh, the
 *   point is re-mapped through the medium's tet mesh (queryAD), which also
 *   supplies J. If that query fails, the vertex is marked EVInvalid.
 * - type == EVInvalid: left untouched.
 *
 * @return false when the tet-mesh query fails; also false (after the assert)
 *         for any vertex type not covered above. true otherwise.
 */
bool getPoint(const Scene &scene, Intersection &v)
{
    if (v.type == EVSensor)
    {
        v.p = scene.camera.cpos;
        return true;
    }
    if (v.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[v.shape_id];
        const Vector3i &ind = shape->indices[v.triangle_id];
        const Vector &v0 = shape->getVertex(ind[0]),
                     &v1 = shape->getVertex(ind[1]),
                     &v2 = shape->getVertex(ind[2]);
        // Barycentric interpolation: (1 - b.x - b.y) * v0 + b.x * v1 + b.y * v2.
        v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
              v.barycentric.x() * v1 +
              v.barycentric.y() * v2;
        Vector geo_n = shape->getFaceNormal(v.triangle_id);
        Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
        v.geoFrame = Frame(geo_n);
        v.shFrame = Frame(sh_n);
        // J = area / detach(area). NOTE(review): presumably value 1 that only
        // carries the area's derivative -- confirm against detach().
        v.J = shape->getArea(v.triangle_id);
        v.J /= detach(v.J);
        return true;
    }
    if (v.type & EVVolume)
    {
        assert(v.medium_id >= 0);
        const Medium *medium = scene.getMedium(v.medium_id);
        v.J = 1;
        if (scene.state.use_tetmesh)
        {
            // queryAD writes the re-mapped point into q and its factor into J
            // (assumption from the call site -- see the tet-mesh implementation).
            Vector q;
            Float J;
            assert(medium->m_tetmesh.state == ESConfigured);
            if (!medium->m_tetmesh.queryAD(scene, v.p, q, J))
            {
                // Point could not be located in the tet mesh: invalidate the vertex.
                v.type = EVInvalid;
                return false;
            }
            v.J = J;
            v.p = q;
        }
        return true;
    }
    if (v.type == EVInvalid)
        return true;
    // Unknown vertex type.
    assert(false);
    return false;
}

/**
 * Runs getPoint() on every vertex of `path`, in order, updating them in place.
 *
 * @return false as soon as getPoint() rejects a vertex (later vertices are not
 *         visited); true when every vertex was accepted.
 */
bool getPath(const Scene &scene, LightPath &path)
{
    for (auto &vertex : path.vertices)
        if (!getPoint(scene, vertex))
            return false;
    return true;
}

/**
 * Transmittance along the segment v1.p -> v2.p through `medium`.
 *
 * Thin wrapper over Scene::evalTransmittance. Whether each endpoint lies on a
 * surface (its EVSurface bit) is forwarded alongside the point.
 */
Float evalTransmittance(const Scene &scene, const Intersection &v1, const Intersection &v2,
                        const Medium *medium, RndSampler *sampler)
{
    return scene.evalTransmittance(v1.p, static_cast<bool>(v1.type & EVSurface),
                                   v2.p, static_cast<bool>(v2.type & EVSurface),
                                   medium, sampler);
}

/**
 * Same query as evalTransmittance(), but with Scene::evalTransmittance's
 * trailing flag set to true, which selects its "ratio" variant.
 * NOTE(review): presumably T / detach(T) -- confirm in Scene::evalTransmittance.
 */
Float evalTransmittanceRatio(const Scene &scene, const Intersection &v1, const Intersection &v2,
                             const Medium *medium, RndSampler *sampler)
{
    const bool firstOnSurface  = static_cast<bool>(v1.type & EVSurface);
    const bool secondOnSurface = static_cast<bool>(v2.type & EVSurface);
    return scene.evalTransmittance(v1.p, firstOnSurface, v2.p, secondOnSurface,
                                   medium, sampler, true);
}

/**
 * Contribution of the camera segment curV (sensor) -> nextV for pixel
 * `pixel_idx`: the camera filter weight at nextV times the transmittance
 * ratio through the camera's medium.
 *
 * Volume hits evaluate the filter from the position alone. Surface hits also
 * pass the geometric normal; if both type bits are set, the surface form wins.
 * Any other vertex type yields zero.
 */
Spectrum evalFirstVertexContrib(
    const Scene &scene, RndSampler *sampler, const Array2i &pixel_idx,
    const Intersection &curV, const Intersection &nextV)
{
    assert(curV.type & EVSensor);
    const int px = pixel_idx[0];
    const int py = pixel_idx[1];

    Spectrum filterWeight = Spectrum::Zero();
    if (nextV.type & EVVolume)
        filterWeight = scene.camera.evalFilter(px, py, nextV.p);
    if (nextV.type & EVSurface)
        filterWeight = scene.camera.evalFilter(px, py, nextV.p, nextV.geoFrame.n);

    const Medium *cameraMedium = scene.getMedium(curV.medium_id);
    filterWeight *= evalTransmittanceRatio(scene, curV, nextV, cameraMedium, sampler);
    return filterWeight;
}

/**
 * Contribution of an interior vertex curV on the segment chain
 * preV -> curV -> nextV:
 *
 *   (phase * sigma_s | bsdf) * transmittanceRatio(curV, nextV) * G * curV.J / curV.pdf
 *
 * - Volume vertex: phase function of curV's medium times sigS(curV.p).
 *   The transmittance to nextV is taken in curV's medium.
 * - Surface vertex: BSDF of curV's shape. The transmittance is taken in the
 *   medium on the side of the surface facing nextV.
 * - G is the volume geometric term, or the surface form using nextV's normal.
 *
 * Side effect: stores the direction towards preV in curV.wi. It is in the
 * local frame for surface vertices and in world space for volume vertices.
 */
Spectrum evalVertexContrib(
    const Scene &scene, RndSampler *sampler,
    const Intersection &preV, Intersection &curV, const Intersection &nextV) {
    // World-space unit directions from curV towards its neighbours (computed once).
    const Vector toPre  = (preV.p - curV.p).normalized();
    const Vector toNext = (nextV.p - curV.p).normalized();

    if (curV.type & EVSurface) {
        curV.wi = curV.toLocal(toPre);
    }
    if (curV.type & EVVolume) {
        curV.wi = toPre;
    }
    assert(curV.type != EVInvalid);

    // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
    const Medium *medium = nullptr;  // medium traversed on the way to nextV
    Spectrum      value  = Spectrum::Ones();

    assert(curV.type & EVVolume || curV.type & EVSurface);
    if (curV.type & EVVolume) {
        medium                     = scene.getMedium(curV.medium_id);
        const PhaseFunction *phase = scene.getPhase(medium->phase_id);
        value *= phase->eval(toPre, toNext) * medium->sigS(curV.p);
    }
    if (curV.type & EVSurface) {
        medium                 = scene.getMedium(curV.getTargetMediumId(toNext));
        const Shape *ptr_shape = scene.shape_list[curV.shape_id];
        const BSDF  *ptr_bsdf  = scene.bsdf_list[ptr_shape->bsdf_id];
        // Reads curV.wi, which was set above.
        value *= ptr_bsdf->eval(curV, curV.toLocal(toNext));
    }

    value *= evalTransmittanceRatio(scene, curV, nextV, medium, sampler);
    assert(nextV.type & EVVolume || nextV.type & EVSurface);
    if (nextV.type & EVVolume) {
        value *= geometric(curV.p, nextV.p);
    }
    if (nextV.type & EVSurface) {
        value *= geometric(curV.p, nextV.p, nextV.geoFrame.n);
    }
    value *= curV.J / curV.pdf;
    return value;
}

/**
 * Emitted term for the final segment curV -> neeV, where neeV lies on an
 * emitter shape:
 *
 *   Le(neeV, towards curV) * detach(T(curV, neeV)) * neeV.J / neeV.pdf
 *
 * The transmittance is detached: it only fixes up the value of the
 * transmittance ratio that the caller already folded into curV's own
 * contribution (evalVertexContrib, or evalFirstVertexContrib for the camera).
 * The medium chosen here must therefore match the one used for that ratio:
 *   - camera vertex  -> the camera's medium (as in evalFirstVertexContrib),
 *   - volume vertex  -> curV's medium,
 *   - surface vertex -> the medium on the side of curV facing neeV.
 */
Spectrum evalLastVertexContrib(const Scene        &scene,  RndSampler *sampler,
                               const Intersection &curV, const Intersection &neeV) {
    const Shape   *ptr_emitter = scene.shape_list[neeV.shape_id];
    const Emitter *emitter     = scene.emitter_list[ptr_emitter->light_id];

    // FIXME: evalTransmittance is used here only to adjust the value of the
    // transmittance ratio evaluated for curV.
    const Medium *medium = nullptr;
    if (curV.type == EVSensor) {
        // Camera looking straight at an emitter: previously left as nullptr,
        // which dropped the attenuation of a participating camera medium even
        // though evalFirstVertexContrib used that medium for the ratio.
        medium = scene.getMedium(curV.medium_id);
    }
    if (curV.type & EVVolume) {
        medium = scene.getMedium(curV.medium_id);
    }
    if (curV.type & EVSurface) {
        medium = scene.getMedium(curV.getTargetMediumId((neeV.p - curV.p).normalized()));
    }
    Float transmittance = evalTransmittance(scene, curV, neeV, medium, sampler);

    // Emission direction: from the emitter point back towards curV.
    const Vector toCur = (curV.p - neeV.p).normalized();
    Spectrum em  = emitter->eval(neeV.geoFrame.n, toCur);
    Spectrum ret = em * detach(transmittance) *
                   neeV.J / neeV.pdf;
    return ret;
}

/**
 * Fills in, in place, the per-vertex terms of `path` that evalPath() later
 * multiplies together.
 *
 * - vertices[0] is the camera vertex. Its `value` receives
 *   evalFirstVertexContrib() towards the first hit (vertices[vs[1]]).
 *   If that hit is an emitter, camV.nee_bsdf receives
 *   evalLastVertexContrib(camV, firstV).
 * - For each later main-path vertex (indices from path.vs) except the last,
 *   `value` receives evalVertexContrib() towards the next main-path vertex.
 * - For each of the vertex's side vertices (l_nee_id, l_bsdf_id; -1 = absent),
 *   the vertex's nee_bsdf / bsdf_bsdf receives evalVertexContrib() towards
 *   that side vertex. The side vertex's own `value` receives
 *   evalLastVertexContrib().
 *
 * Returns without writing anything when path.vs has at most one entry or the
 * first hit is EVInvalid.
 *
 * NOTE(review): camV is vertices[0] while the loop uses vertices[vs[0]] as
 * the first preV; this assumes vs[0] == 0 -- confirm with the path builder.
 */
void evalVertex(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    const Array2i &pixel_idx = path.pixel_idx;
    auto &vertices = path.vertices;
    const auto &vtxIds = path.vs;

    Intersection &camV = vertices[0];
    if (path.vs.size() <= 1 || vertices[vtxIds[1]].type == EVInvalid)
        return;
    Intersection &firstV = vertices[vtxIds[1]];
    camV.value = evalFirstVertexContrib(scene, sampler, pixel_idx, camV, firstV);
    if (firstV.type == EVEmitter)
    {
        // Camera sees an emitter directly: store its Le term on the camera vertex.
        camV.nee_bsdf = evalLastVertexContrib(scene, sampler, camV, firstV);
    }
    for (int i = 1; i < vtxIds.size(); i++)
    {
        Intersection &preV = vertices[vtxIds[i - 1]];
        Intersection &curV = vertices[vtxIds[i]];
        // The last main-path vertex has no successor, so its `value` is not set here.
        if (i < vtxIds.size() - 1)
        {
            Intersection &nextV = vertices[vtxIds[i + 1]];
            curV.value = evalVertexContrib(scene, sampler, preV, curV, nextV);
        }
        // handle nee: emitter sampling and bsdf sampling
        for (int nee_id : {curV.l_nee_id, curV.l_bsdf_id})
        {
            if (nee_id == -1)
                continue;
            // compute bsdf/phase * (sig_s) * transmittance * G * J / pdf
            Intersection &neeV = vertices[nee_id];
            Spectrum value = evalVertexContrib(scene, sampler, preV, curV, neeV);
            // If both ids are equal, both slots receive the same value.
            if (nee_id == curV.l_nee_id){
                curV.nee_bsdf = value;
            }
            if (nee_id == curV.l_bsdf_id){
                curV.bsdf_bsdf = value;
            }

            // compute Le
            neeV.value = evalLastVertexContrib(scene, sampler, curV, neeV);
        }
    }
}

/**
 * Combines the per-vertex terms written by evalVertex() into the path's
 * total contribution.
 *
 * - If the first hit is an emitter, camV.value * camV.nee_bsdf is added.
 * - Walking the main path (path.vs), the running throughput is multiplied by
 *   each predecessor's `value`. For every vertex with a side vertex
 *   (l_nee_id / l_bsdf_id >= 0), throughput * scattering term * side vertex
 *   `value` is added.
 *
 * Returns zero when path.vs has at most one entry or the first hit is
 * EVInvalid. `scene` is currently unused.
 */
Spectrum evalPath(const Scene &scene, LightPath &path)
{
    const auto &ids = path.vs;
    if (ids.size() <= 1 || path.vertices[ids[1]].type == EVInvalid)
        return Spectrum::Zero();

    Spectrum total = Spectrum::Zero();
    const Intersection &camV = path.vertices[0];
    if (path.vertices[ids[1]].type == EVEmitter)
        total += camV.value * camV.nee_bsdf;

    Spectrum throughput = Spectrum::Ones();
    for (size_t k = 1; k < ids.size(); ++k)
    {
        const Intersection &preV = path.vertices[ids.at(k - 1)];
        const Intersection &curV = path.vertices[ids.at(k)];
        throughput *= preV.value;

        // emitter sampling
        if (curV.l_nee_id >= 0)
            total += throughput * curV.nee_bsdf * path.vertices[curV.l_nee_id].value;

        // bsdf sampling
        if (curV.l_bsdf_id >= 0)
            total += throughput * curV.bsdf_bsdf * path.vertices[curV.l_bsdf_id].value;
    }
    return total;
}

/**
 * Full evaluation of one light path:
 *   1. resolve vertex positions (getPath),
 *   2. compute per-vertex terms (evalVertex),
 *   3. combine them (evalPath).
 *
 * Returns zero if vertex resolution fails or the result is not finite.
 * `path` is modified in place.
 */
Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    if (getPath(scene, path))
    {
        evalVertex(scene, path, sampler);
        const Spectrum contribution = evalPath(scene, path);
        if (contribution.allFinite())
            return contribution;
    }
    return Spectrum::Zero();
}

NAMESPACE_END(algorithm1_vol)