#include "algorithm1.h"
#include <render/scene.h>
#include <core/statistics.h>
NAMESPACE_BEGIN(algorithm1_ptracer)

/// Transmittance between two path vertices inside `medium`.
///
/// An endpoint is treated as lying on a surface when its vertex type carries
/// the EVSurface bit. The evaluation itself is delegated to
/// Scene::evalTransmittance; `medium` and `sampler` are forwarded unchanged.
Float evalTransmittance(const Scene &scene, const Intersection &v1, const Intersection &v2,
                        const Medium *medium, RndSampler *sampler)
{
    const bool firstOnSurface = v1.type & EVSurface;
    const bool secondOnSurface = v2.type & EVSurface;
    return scene.evalTransmittance(v1.p, firstOnSurface,
                                   v2.p, secondOnSurface,
                                   medium,
                                   sampler);
}

/**
 * @brief Recompute the geometric data of a path vertex from its parametric
 *        description (type, shape/triangle ids, barycentrics, medium id).
 *
 * - EVSensor:  v.p is set to the camera position (scene.camera.cpos).
 * - EVSurface: v.p is interpolated from the triangle's vertices with
 *              v.barycentric; the geometric and shading frames are rebuilt and
 *              v.J is set to area / detach(area).
 * - EVVolume:  v.p is replaced by the point q returned by the medium's
 *              tetrahedral-mesh query (queryAD), and v.J by the J it returns.
 * - EVInvalid: returns false.
 *
 * getPoint is differentiated with Enzyme in d_getPoint(), so edits here change
 * the computed gradients as well as the primal values.
 *
 * @return true on success; false for EVInvalid vertices or when the tet-mesh
 *         query fails.
 */
bool getPoint(const Scene &scene, Intersection &v)
{
    if (v.type == EVSensor)
    {
        v.p = scene.camera.cpos;
        return true;
    }
    if (v.type & EVSurface)
    {
        const Shape *shape = scene.shape_list[v.shape_id];
        const Vector3i &ind = shape->indices[v.triangle_id];
        const Vector &v0 = shape->getVertex(ind[0]),
                     &v1 = shape->getVertex(ind[1]),
                     &v2 = shape->getVertex(ind[2]);
        // barycentric interpolation: (1 - b.x - b.y) * v0 + b.x * v1 + b.y * v2
        v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
              v.barycentric.x() * v1 +
              v.barycentric.y() * v2;
        Vector geo_n = shape->getFaceNormal(v.triangle_id);
        Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
        v.geoFrame = Frame(geo_n);
        v.shFrame = Frame(sh_n);
        // J evaluates to 1. Assuming detach() blocks gradient flow, its
        // derivative is d(area)/area, i.e. the relative change of the triangle
        // area. NOTE(review): confirm detach() semantics.
        v.J = shape->getArea(v.triangle_id);
        v.J /= detach(v.J);
        return true;
    }
    // =========================== for volume =============================
    if (v.type & EVVolume)
    {
        assert(v.medium_id >= 0);
        const Medium *medium = scene.getMedium(v.medium_id);
        Vector q;
        Float J;
        assert(medium->m_tetmesh.state == ESConfigured);
        // Presumably maps v.p through the (differentiable) tet mesh to the
        // current position q and its volume Jacobian J. TODO: confirm.
        if (!medium->m_tetmesh.queryAD(scene, v.p, q, J))
        {
            return false;
        }
        v.J = J;
        v.p = q;
        return true;
    }
    // =========================== for volume =============================

    if (v.type == EVInvalid)
        return false;
    // any other vertex type is unexpected here
    assert(false);
    return false;
}

/// Resolve every vertex of `path` with getPoint().
///
/// Vertices are processed in storage order; processing stops at the first
/// vertex that cannot be resolved.
///
/// @return false as soon as one vertex fails, true if all succeed
bool getPath(const Scene &scene, LightPath &path)
{
    for (auto &vertex : path.vertices)
    {
        const bool resolved = getPoint(scene, vertex);
        if (!resolved)
            return false;
    }
    return true;
}

/// Reverse-mode derivative of getPoint() for a single vertex.
///
/// The adjoint of the vertex is read from / accumulated into `d_v`. Scene
/// gradients go into the calling thread's slot of sceneAD.gm, which is indexed
/// by the OpenMP thread id. When the build does not define both ENZYME and
/// PTRACER, the Enzyme call is compiled out and this function does nothing.
void d_getPoint(SceneAD &sceneAD, const Intersection &v, Intersection &d_v)
{
    const int tid = omp_get_thread_num();
    auto &scene_grad = sceneAD.gm.get(tid);
#if defined(ENZYME) && defined(PTRACER)
    __enzyme_autodiff((void *)getPoint,
                      enzyme_dup, &sceneAD.val, &scene_grad,
                      enzyme_dup, &v, &d_v);
#endif
}

/// Back-propagate vertex adjoints through getPoint() for a whole path.
///
/// Only surface and volume vertices are differentiated. Other vertices are
/// skipped, including the sensor vertex whose position comes from
/// scene.camera.cpos. The adjoint of vertex k lives at pathAD.der.vertices[k].
void d_getPath(SceneAD &sceneAD, LightPathAD &pathAD)
{
    auto &primal = pathAD.val.vertices;
    auto &adjoint = pathAD.der.vertices;
    for (int k = 0; k < primal.size(); ++k)
    {
        const auto &vtx = primal[k];
        const bool differentiable = (vtx.type & EVSurface) || (vtx.type & EVVolume);
        if (differentiable)
            d_getPoint(sceneAD, vtx, adjoint[k]);
    }
}

/**
 * @brief Compute the contribution factor of one path vertex and store it in
 *        v->value.
 *
 * t.intersections holds the (previous, current, next) vertex pointers, and
 * t.type selects the factor:
 *   0 - camera:          scene.camera.eval for the vertex pixel and the
 *                        previous vertex position (next vertex unused, may be
 *                        null)
 *   1 - emitter:         emitter eval * cos * J / pdf (previous vertex unused,
 *                        may be null)
 *   2 - BSDF:            BSDF eval * J / pdf
 *   3 - BSDF correction: as 2, but evaluated with EImportanceWithCorrection
 *   4 - volume:          phase eval * sigma_s(p) * J / pdf
 * For types 1-4 the value is further multiplied by the geometric term and the
 * transmittance towards the next vertex. t.sampler is used for the
 * transmittance estimate.
 *
 * This function is differentiated with Enzyme in d_evalVertexContrib().
 */
void evalVertexContrib(const Scene &scene, const Tuple &t)
{
    auto [vp, v, vn] = t.intersections;
    // camera
    // The previous vertex may be a volume or a surface vertex; both branches
    // evaluate the same camera response.
    if (t.type == 0)
    {
        assert(v->type & EVSensor);
        if (vp->type & EVVolume)
            v->value = scene.camera.eval(v->pixel_idx[0], v->pixel_idx[1],
                                         vp->p);
        if (vp->type & EVSurface)
            v->value = scene.camera.eval(v->pixel_idx[0], v->pixel_idx[1],
                                         vp->p);
        // assert(v->value.allFinite());
    }
    // emitter :  f * cos(theta) / pdf
    if (t.type == 1)
    {
        assert(v->type == EVEmitter);
        const Shape *ptr_emitter = scene.shape_list[v->shape_id];
        const Emitter *emitter = scene.emitter_list[ptr_emitter->light_id];
        v->value = emitter->eval(v->geoFrame.n, (vn->p - v->p).normalized()) *
                   (vn->p - v->p).normalized().dot(v->geoFrame.n) *
                   v->J / v->pdf;
        // assert(v->value.allFinite());
    }
    // bsdf : f * cos(theta) / pdf
    // v->wi is set to the local direction towards the previous vertex before
    // the BSDF is evaluated for the direction towards the next vertex.
    if (t.type == 2)
    {
        v->wi = v->toLocal((vp->p - v->p).normalized());
        const Shape *ptr_shape = scene.shape_list[v->shape_id];
        const BSDF *ptr_bsdf = scene.bsdf_list[ptr_shape->bsdf_id];
        v->value = ptr_bsdf->eval(*v, v->toLocal((vn->p - v->p).normalized()));
        v->value *= v->J / v->pdf;
        // assert(v->value.allFinite());
    }
    // bsdf correction : f * cos(theta) / pdf
    if (t.type == 3)
    {
        assert(v->type & EVSurface);
        v->wi = v->toLocal((vp->p - v->p).normalized());
        const Shape *ptr_shape = scene.shape_list[v->shape_id];
        const BSDF *ptr_bsdf = scene.bsdf_list[ptr_shape->bsdf_id];
        v->value = ptr_bsdf->eval(*v, v->toLocal((vn->p - v->p).normalized()), EImportanceWithCorrection);
        v->value *= v->J / v->pdf;
        // assert(v->value.allFinite());
    }
    // volume : f / pdf
    if (t.type == 4)
    {
        assert(v->type & EVVolume);
        const Medium *medium = scene.getMedium(v->medium_id);
        const PhaseFunction *phase = scene.getPhase(medium->phase_id);
        v->value = phase->eval((vp->p - v->p).normalized(),
                               (vn->p - v->p).normalized()) *
                   medium->sigS(v->p);
        v->value *= v->J / v->pdf;
        // assert(v->value.allFinite());
    }

    // geometric
    // Geometric term towards the next vertex. A surface endpoint also passes
    // its geometric normal; a sensor endpoint contributes no term here.
    if (t.type == 1 || t.type == 2 || t.type == 3 || t.type == 4)
    {
        if (vn->type == EVSensor)
        {
        }
        else if (vn->type & EVVolume)
        {
            v->value *= geometric(v->p, vn->p);
        }
        else if (vn->type & EVSurface)
        {
            v->value *= geometric(v->p, vn->p, vn->geoFrame.n);
        }
        else
        {
            assert(false);
        }
    }
    // transmittance
    // By default v's own medium is used. When v lies on a shape that is a
    // medium transition, the medium returned by getTargetMedium for the
    // direction towards vn is used instead (presumably the medium on vn's
    // side of the boundary).
    // NOTE(review): the volume and surface branches perform the same call.
    if (t.type == 1 || t.type == 2 || t.type == 3 || t.type == 4)
    {
        const Medium *medium = scene.getMedium(v->medium_id);
        if (v->ptr_shape && v->isMediumTransition())
            medium = v->getTargetMedium((vn->p - v->p));
        if (v->type & EVVolume)
        {
            v->value *= evalTransmittance(scene, *v, *vn, medium, t.sampler);
        }
        else if (v->type & EVSurface)
        {
            v->value *= evalTransmittance(scene, *v, *vn, medium, t.sampler);
        }
        else
        {
            assert(false);
        }
    }
    // assert(v->value.allFinite());
}

/// Reverse-mode derivative of evalVertexContrib().
///
/// tAD.val is the primal tuple (pointers into the primal path), and tAD.der is
/// the matching tuple of pointers into the adjoint path. Scene gradients go
/// into the calling thread's slot of sceneAD.gm. When the build does not
/// define both ENZYME and PTRACER, the Enzyme call is compiled out.
void d_evalVertexContrib(SceneAD &sceneAD, const TupleAD &tAD)
{
    const int tid = omp_get_thread_num();
    auto &scene_grad = sceneAD.gm.get(tid);
#if defined(ENZYME) && defined(PTRACER)
    __enzyme_autodiff((void *)evalVertexContrib,
                      enzyme_dup, &sceneAD.val, &scene_grad,
                      enzyme_dup, &tAD.val, &tAD.der);
#endif
}

/// Fill in `value` for every vertex of a particle path.
///
/// path.path lists vertex indices from the emitter (first) to the camera
/// (last). The vertices are evaluated as follows:
///   - emitter: type 1;
///   - each interior vertex: type 4 (volume: phase function) or type 3
///     (surface: BSDF with importance correction);
///   - camera vertex: type 0.
/// `sampler` is forwarded for the transmittance estimates.
void evalVertex(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    const int len = path.path.size();
    // address of the k-th vertex along the path
    auto vertexAt = [&path](int k) { return &path.vertices[path.path[k]]; };

    // emitter: no predecessor
    evalVertexContrib(scene, {1, sampler, {nullptr, vertexAt(0), vertexAt(1)}});

    // interior vertices: phase function (volume) or corrected BSDF (surface)
    for (int k = 1; k + 1 < len; ++k)
    {
        auto *prev = vertexAt(k - 1);
        auto *cur = vertexAt(k);
        auto *next = vertexAt(k + 1);
        assert(cur->type & EVSurface || cur->type & EVVolume);
        if (cur->type & EVVolume)
            evalVertexContrib(scene, {4, sampler, {prev, cur, next}});
        else
            evalVertexContrib(scene, {3, sampler, {prev, cur, next}});
    }

    // camera: no successor
    evalVertexContrib(scene, {0, sampler, {vertexAt(len - 2), vertexAt(len - 1), nullptr}});
}

/// Reverse-mode counterpart of evalVertex().
///
/// For each vertex it calls d_evalVertexContrib with a primal tuple built from
/// pathAD.val and an adjoint tuple built from pathAD.der. Both are addressed
/// through the primal ordering pathAD.val.path. The vertex types (1 / 3 / 4 / 0)
/// match those used by evalVertex().
void d_evalVertex(SceneAD &sceneAD, LightPathAD &pathAD, RndSampler *sampler)
{
    auto &primal = pathAD.val;
    auto &adjoint = pathAD.der;
    const int len = primal.path.size();
    // primal / adjoint vertex at position k along the (primal) path
    auto valAt = [&primal](int k) { return &primal.vertices[primal.path[k]]; };
    auto derAt = [&primal, &adjoint](int k) { return &adjoint.vertices[primal.path[k]]; };

    // emitter: no predecessor
    d_evalVertexContrib(sceneAD, {{1, sampler, {nullptr, valAt(0), valAt(1)}},
                                  {1, sampler, {nullptr, derAt(0), derAt(1)}}});

    // interior vertices: phase function (volume) or corrected BSDF (surface)
    for (int k = 1; k + 1 < len; ++k)
    {
        auto *cur = valAt(k);
        assert(cur->type & EVSurface || cur->type & EVVolume);
        if (cur->type & EVVolume)
            d_evalVertexContrib(sceneAD, {{4, sampler, {valAt(k - 1), cur, valAt(k + 1)}},
                                          {4, sampler, {derAt(k - 1), derAt(k), derAt(k + 1)}}});
        else
            d_evalVertexContrib(sceneAD, {{3, sampler, {valAt(k - 1), cur, valAt(k + 1)}},
                                          {3, sampler, {derAt(k - 1), derAt(k), derAt(k + 1)}}});
    }

    // camera: no successor
    d_evalVertexContrib(sceneAD, {{0, sampler, {valAt(len - 2), valAt(len - 1), nullptr}},
                                  {0, sampler, {derAt(len - 2), derAt(len - 1), nullptr}}});
}

/// Path throughput: the product of the per-vertex factors stored in `value`
/// (written by evalVertex()). Only vertices referenced by path.path
/// participate. `scene` is not used.
Spectrum evalPath(const Scene &scene, LightPath &path)
{
    Spectrum throughput = Spectrum::Ones();
    for (const auto &vertex_index : path.path)
        throughput *= path.vertices[vertex_index].value;
    return throughput;
}

// Wrapper around evalPath() that returns its result through the out-parameter
// `value` (overwriting it). d_evalPath() differentiates this wrapper with
// Enzyme, passing a shadow for `value` that carries the output adjoint.
// NOTE(review): identifiers starting with a double underscore are reserved for
// the implementation in C++; consider renaming if nothing external binds to it.
void __evalPath(const Scene &scene, LightPath &path, Spectrum &value)
{
    value = evalPath(scene, path);
}

/// Reverse-mode derivative of evalPath() via the __evalPath wrapper.
///
/// `d_value` seeds the adjoint of the path throughput. Gradients flow into
/// pathAD.der and into the calling thread's slot of sceneAD.gm.
/// NOTE(review): guarded only by ENZYME, whereas d_getPoint and
/// d_evalVertexContrib also require PTRACER. Confirm this is intentional.
void d_evalPath(SceneAD &sceneAD, LightPathAD &pathAD, Spectrum d_value)
{
    auto &scene_grad = sceneAD.gm.get(omp_get_thread_num());
    // Primal output slot for Enzyme. __evalPath overwrites it, so its initial
    // contents do not matter.
    Spectrum primal_out = Spectrum::Ones();
#ifdef ENZYME
    __enzyme_autodiff((void *) __evalPath,
                      enzyme_dup, &sceneAD.val, &scene_grad,
                      enzyme_dup, &pathAD.val, &pathAD.der,
                      enzyme_dup, &primal_out, &d_value);
#endif
}

/// Evaluate the contribution of a particle path.
///
/// The vertices are resolved in place (getPath), then each vertex factor is
/// computed (evalVertex) and the factors are multiplied together (evalPath).
///
/// @return Spectrum::Zero() if some vertex cannot be resolved, otherwise the
///         path throughput
Spectrum eval(const Scene &scene, LightPath &path, RndSampler *sampler)
{
    const bool resolved = getPath(scene, path);
    if (resolved)
    {
        evalVertex(scene, path, sampler);
        return evalPath(scene, path);
    }
    return Spectrum::Zero();
}

/// Reverse-mode gradient of eval() for one particle path.
///
/// Forward pass: resolve the vertices, evaluate the vertex factors and the
/// throughput. Paths that cannot be resolved, or whose throughput is not
/// finite, contribute no gradient.
///
/// Backward pass: propagate `d_value` through evalPath, then evalVertex, then
/// getPath (the reverse of the forward order). Gradients land in pathAD.der
/// and in the calling thread's slot of sceneAD.gm.
void d_eval(SceneAD &sceneAD, LightPathAD &pathAD,
            Spectrum d_value, RndSampler *sampler)
{
    const auto &scene = sceneAD.val;
    auto &primal = pathAD.val;

    // ---- forward (primal) pass ----
    if (!getPath(scene, primal))
        return;
    evalVertex(scene, primal, sampler);
    const Spectrum throughput = evalPath(scene, primal);
    if (!throughput.allFinite())
        return;

    // ---- backward (adjoint) pass, in reverse order ----
    d_evalPath(sceneAD, pathAD, d_value); // algorithm 1
    d_evalVertex(sceneAD, pathAD, sampler);
    d_getPath(sceneAD, pathAD);
}

/**
 * @brief find the antithetic intersection point
 *
 * @param scene
 * @param path
 * @return std::tuple<bool, Intersection, Float, Float, Float, Float>
 */
std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic(const Scene &scene, const LightPath &path)
{
    // camera
    auto &vc = path.vertices[path.path[path.path.size() - 1]];
    auto &v = path.vertices[path.path[path.path.size() - 2]];
    Ray ray(vc.p, v.p - vc.p);
    Vector2 pixel_uv;
    Vector dir;
    scene.camera.sampleDirect(v.p, pixel_uv, dir);
    Ray dual = scene.camera.sampleDualRay(vc.pixel_idx, pixel_uv);

    Float pdf1A = v.pdf,
          pdf1B = 0, pdf2A = 0, pdf2B = 0;
    Intersection its;

    if (!scene.rayIntersect(dual, true, its))
        return {false, its, v.pdf, 0, 0, 0};

    Float G_camA = scene.camera.geometric(v.p, v.geoFrame.n);
    Float G_camB = scene.camera.geometric(its.p, its.geoFrame.n);
    Float pdf_pixel = v.pdf / G_camA;
    its.pdf = pdf_pixel * G_camB;
    pdf2B = its.pdf;
    if (path.path.size() == 2)
    {
        if (!its.isEmitter())
            return {false, its, pdf1A, 0, 0, pdf2B};
        pdf2A = scene.pdfEmitterSample(its);
        pdf1B = pdf2A / G_camB * G_camA;
    }
    if (path.path.size() > 2)
    {
        auto &vp = path.vertices[path.path[path.path.size() - 3]];

        if (!scene.isVisible(vp.p, true, its.p, true))
            return {false, its, v.pdf, 0, 0, 0};
        if (vp.isEmitter())
            pdf2A = vp.ptr_emitter->pdf(its, vp.toLocal((its.p - vp.p).normalized())) * geometric(vp.p, its.p, its.geoFrame.n);
        else
            pdf2A = vp.pdfBSDF(vp.toLocal((its.p - vp.p).normalized())) * geometric(vp.p, its.p, its.geoFrame.n);
        pdf1B = pdf2A / G_camB * G_camA;
    }
    return {true, its, pdf1A, pdf1B, pdf2A, pdf2B};
}

/**
 * @brief Find the antithetic surface vertex for a path that may traverse
 *        participating media.
 *
 * Same construction as antithetic(): the vertex v adjacent to the camera is
 * projected onto the image plane and the camera's dual ray for that pixel is
 * traced, here with Scene::traceForSurface. Differences from antithetic():
 *   - pdf2A includes the transmittance-sampling pdf between the preceding
 *     vertex vp and the antithetic point `its`;
 *   - vp may be the emitter (only when it is the first vertex of a length-3
 *     path), a surface vertex (BSDF pdf), or a volume vertex (phase pdf using
 *     the direction from the vertex before vp).
 *
 * @param scene   scene used for tracing, visibility and pdf evaluation
 * @param path    particle path whose last vertex is the camera vertex
 * @param sampler random source forwarded to Scene::evalTransmittanceAndPdf
 * @return {success, its, pdf1A, pdf1B, pdf2A, pdf2B}. pdf1A is always v.pdf;
 *         the other densities are only meaningful when success is true.
 */
__attribute__((optnone)) std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_surf(const Scene &scene, const LightPath &path, RndSampler *sampler)
{
    // camera vertex and the vertex connected to it
    auto &vc = path.vertices[path.path[path.path.size() - 1]];
    auto &v = path.vertices[path.path[path.path.size() - 2]];
    // project v onto the image plane, then build the dual ray for the same pixel
    Vector2 pixel_uv;
    Vector dir;
    scene.camera.sampleDirect(v.p, pixel_uv, dir);
    Ray dual = scene.camera.sampleDualRay(vc.pixel_idx, pixel_uv);

    Float pdf1A = v.pdf,
          pdf1B = 0, pdf2A = 0, pdf2B = 0;
    Intersection its;

    if (!scene.traceForSurface(dual, true, its))
    {
        return {false, its, v.pdf, 0, 0, 0};
    }

    // pdf2B: v's area density -> pixel measure -> area density at its
    Float G_camA = scene.camera.geometric(v.p, v.geoFrame.n);
    Float G_camB = scene.camera.geometric(its.p, its.geoFrame.n);
    Float pdf_pixel = v.pdf / G_camA;
    its.pdf = pdf_pixel * G_camB;
    pdf2B = its.pdf;
    if (path.path.size() == 2)
    {
        // v is the emitter vertex, so the antithetic point must be on an emitter too
        if (!its.isEmitter())
            return {false, its, pdf1A, 0, 0, pdf2B};
        pdf2A = scene.pdfEmitterSample(its);
        pdf1B = pdf2A / G_camB * G_camA;
    }
    if (path.path.size() > 2)
    {
        auto &vp = path.vertices[path.path[path.path.size() - 3]];
        bool p1OnSurface = vp.type & EVSurface;
        auto [trans, trans_pdf] = scene.evalTransmittanceAndPdf(vp.p, p1OnSurface, its.p, true, scene.getMedium(vp.medium_id), sampler);
        // vp can be a volume vertex, so pass its real surface flag (as for the
        // transmittance query above) instead of assuming it lies on a surface.
        if (!scene.isVisible(vp.p, p1OnSurface, its.p, true))
            return {false, its, v.pdf, 0, 0, 0};
        if (vp.isEmitter() && path.path.size() == 3)
            pdf2A = vp.ptr_emitter->pdf(vp, vp.toLocal((its.p - vp.p).normalized())) * trans_pdf * geometric(vp.p, its.p, its.geoFrame.n);
        else if (vp.type & EVSurface)
        {
            pdf2A = vp.pdfBSDF(vp.toLocal((its.p - vp.p).normalized())) * trans_pdf * geometric(vp.p, its.p, its.geoFrame.n);
        }
        else if (vp.type & EVVolume)
        {
            // phase pdf for the turn (vpp -> vp -> its)
            auto &vpp = path.vertices[path.path[path.path.size() - 4]];
            auto vp_medium = scene.getMedium(vp.medium_id);
            auto phase = scene.getPhase(vp_medium->phase_id);
            pdf2A = phase->pdf((vpp.p - vp.p).normalized(), (its.p - vp.p).normalized()) * trans_pdf * geometric(vp.p, its.p, its.geoFrame.n);
        }
        else
            assert(false);

        pdf1B = pdf2A / G_camB * G_camA;
    }
    return {true, its, pdf1A, pdf1B, pdf2A, pdf2B};
}

/**
 * @brief find the antithetic volume interaction if there exists one
 *  step 1: trace the pixel level antithetic ray for an equal distance
 *  step 2: if hit a surface, return. otherwise, record the medium interaction
 *  step 3: compute pdf1A : sampled by primal process, pdf1B : sampled by antithetic process, pdf2A, pdf2B
 *      step 3.1: pdf1A is already computed
 *      step 3.2: pdf2A is computed by evaluating bsdf_pdf/emitter_pdf/phase_pdf * distance_pdf
 *      step 3.3: pdf2B is computed by converting the pdf1A to the pixel plane then convert it into the volume
 *      step 3.4: pdf1B is computed by converting the pdf2A to the pixel plane then convert it into the volume
 * @param scene
 * @param path
 * @param sampler
 * @return std::tuple<bool, Intersection, Float, Float, Float, Float>
 */
__attribute__((optnone)) std::tuple<bool, Intersection, Float, Float, Float, Float> antithetic_vol(const Scene &scene, const LightPath &path, RndSampler *sampler, bool is_equal_trans)
{
    // camera
    auto &vc = path.vertices[path.path[path.path.size() - 1]];
    // volume interaction
    auto &v = path.vertices[path.path[path.path.size() - 2]];
    assert(v.type == EVVolume);
    Ray ray(vc.p, (v.p - vc.p).normalized());
    Vector2 pixel_uv;
    Vector dir;
    scene.camera.sampleDirect(v.p, pixel_uv, dir);
    Ray dual = scene.camera.sampleDualRay(vc.pixel_idx, pixel_uv);
    dual.dir = dual.dir.normalized();
    Float pdf1A = v.pdf, pdf1B = 0,
          pdf2A = 0, pdf2B = 0;

    // step 1: trace the antithetic ray for an equal distance
    const Medium *medium = scene.getMedium(v.medium_id);
    MediumSamplingRecord mRec;
    bool success = false;
    // equal transmittance sampling
    if (is_equal_trans)
    {
        Float tarTrans = scene.evalTransmittance(vc.p, false, v.p, false, scene.getMedium(vc.medium_id), sampler);
        if (tarTrans < Epsilon)
        {
            // Statistics::getInstance().getCounter("ptracer2", "targetTransmittance < Epsilon") += 1;
            return {false, Intersection(), v.pdf, 0, 0, 0};
        }
        success = scene.traceForMedium(dual, false, scene.getMedium(vc.medium_id), tarTrans, sampler, mRec);
    }
    // equal distance sampling
    else
    {
        Float tarDist = (v.p - vc.p).norm();
        if (tarDist < Epsilon)
        {
            // Statistics::getInstance().getCounter("ptracer2", "targetTransmittance < Epsilon") += 1;
            return {false, Intersection(), v.pdf, 0, 0, 0};
        }
        success = scene.traceForMedium2(dual, false, scene.getMedium(vc.medium_id), tarDist, sampler, mRec);
    }
    // step 2: if hit a surface, return. otherwise, record the medium interaction
    if (!success)
    {
        // Statistics::getInstance().getCounter("ptracer2", "unsuccess") += 1;
        return {false, Intersection(), v.pdf, 0, 0, 0};
    }

    Float pdf_pixel = v.pdf / scene.camera.geometric(v.p);
    pdf2B = pdf_pixel * scene.camera.geometric(mRec.p);
    // step 2: if hit a surface, return. otherwise, record the medium interaction
    // if the original point and the antithetic point are not in the same medium, return
    if (mRec.medium != medium)
    {
        // Statistics::getInstance().getCounter("ptracer2", "medium inconsistence") += 1;
        return {false, Intersection(), v.pdf, 0, 0, 0};
    }
    Intersection its(mRec, v.medium_id, pdf2B);

    // step 3: compute pdf2A, adn pdf1B
    auto &vp = path.vertices[path.path[path.path.size() - 3]];
    bool p1OnSurface = vp.type & EVSurface;
    auto [trans, pdf_trans] = scene.evalTransmittanceAndPdf(vp.p, p1OnSurface, mRec.p, false, scene.getMedium(vp.medium_id), sampler);
    if (trans < Epsilon)
    {
        // Statistics::getInstance().getCounter("ptracer2", "trans < Epsilon") += 1;
        return {false, its, v.pdf, 0, 0, 0};
    }

    if (vp.isEmitter() && path.path.size() == 3)
        pdf2A = vp.ptr_emitter->pdf(vp, vp.toLocal((its.p - vp.p).normalized())) * pdf_trans * geometric(vp.p, its.p);
    else if (vp.type & EVSurface)
    {
        pdf2A = vp.pdfBSDF(vp.toLocal((its.p - vp.p).normalized())) * pdf_trans * geometric(vp.p, its.p);
    }
    else if (vp.type & EVVolume)
    {
        auto &vpp = path.vertices[path.path[path.path.size() - 4]];
        auto medium = scene.getMedium(vp.medium_id);
        auto phase = scene.getPhase(medium->phase_id);
        pdf2A = phase->pdf((vpp.p - vp.p).normalized(), (its.p - vp.p).normalized()) * pdf_trans * geometric(vp.p, its.p);
    }
    else
        assert(false);

    pdf1B = pdf2A / scene.camera.geometric(its.p) * scene.camera.geometric(v.p);
    return {true, its, pdf1A, pdf1B, pdf2A, pdf2B};
}

NAMESPACE_END(algorithm1_ptracer)
