#include "volpathBase.h"
#include <core/sampler.h>
#include <render/scene.h>
#include <algorithm1.h>
#include <render/medium.h>
#include <fmt/color.h>

/**
 * Next-event estimation by emitter sampling at a point `p` inside a medium.
 *
 * Samples a point on an emitter as seen from `p` (attenuated through
 * `medium`), weights it by the medium's phase function for (wi, sampled
 * direction) and, when `mis` is set, by the MIS weight against phase sampling.
 *
 * @param medium  medium containing `p`; must be non-null whenever an emitter
 *                contribution is found (its phase function is looked up).
 * @param wi      direction handed to the phase function as its first argument.
 * @param mis     apply the emitter-vs-phase MIS weight when true.
 * @return        the (optionally MIS-weighted) contribution, or zero.
 */
Spectrum VolpathBase::neeEmitter(const Scene &scene, RndSampler *sampler,
                                 const Medium *medium, const Vector &p, const Vector &wi, bool mis)
{
    DirectSamplingRecord dRec(p);
    Spectrum value = scene.sampleAttenuatedEmitterDirect(
        dRec, sampler->next2D(), sampler, medium);
    if (value.isZero(Epsilon))
        return Spectrum(0.);

    // `medium` is dereferenced from here on.
    assert(medium);
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);
    Float phaseVal = phase->eval(wi, dRec.dir);
    if (phaseVal < Epsilon)
        return Spectrum(0.);

    value *= phaseVal;
    if (mis)
    {
        // dRec.pdf / dRec.G presumably converts the emitter pdf into the same
        // (solid-angle) measure as the phase pdf -- confirm against Scene.
        Float mis_weight = miWeight(dRec.pdf / dRec.G, phase->pdf(wi, dRec.dir));
        value *= mis_weight;
    }
    return value;
}

/**
 * Next-event estimation by phase-function sampling at a point `p` in a medium.
 *
 * Samples an outgoing direction from the medium's phase function, traces it
 * looking for an emitter (with attenuation through `medium`), and returns the
 * emitted value times the phase sampling weight. When `mis` is set, the
 * result is weighted against emitter sampling.
 *
 * @param medium  medium containing `p`; must be non-null.
 * @param wi      direction handed to the phase function as its first argument.
 * @param mis     apply the phase-vs-emitter MIS weight when true.
 */
Spectrum VolpathBase::neePhase(const Scene &scene, RndSampler *sampler,
                               const Medium *medium, const Vector &p, const Vector &wi, bool mis)
{
    assert(medium);
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);

    Vector sampledDir;
    // Phase sampling weight (the original note says this equals 1).
    const Float phaseWeight = phase->sample(wi, sampler->next2D(), sampledDir);
    if (phaseWeight < Epsilon)
        return Spectrum(0.);

    Ray probe(p, sampledDir);
    Intersection hit;
    DirectSamplingRecord emitterRec(p);
    const Spectrum emitted = scene.rayIntersectAndLookForEmitter(
        probe, false, sampler, medium, hit, emitterRec);
    if (emitted.isZero(Epsilon))
        return Spectrum(0.);

    Spectrum result = emitted * phaseWeight;
    if (!mis)
        return result;

    const Float emitterPdf = scene.pdfEmitterDirect(emitterRec);
    result *= miWeight(phase->pdf(wi, sampledDir), emitterPdf / emitterRec.G);
    return result;
}

/**
 * Full next-event estimate at a medium point: emitter sampling plus
 * phase-function sampling, each optionally MIS-weighted.
 *
 * @param medium  medium containing `p`; must be non-null.
 * @param mis     forwarded to both strategies.
 */
__attribute__((optnone)) Spectrum VolpathBase::nee(const Scene &scene, RndSampler *sampler,
                                                   const Medium *medium, const Vector &p, const Vector &wi, bool mis)
{
    assert(medium);
    // Evaluate the strategies in separate statements. Both draw random numbers
    // from `sampler`, and the operands of `+` are unsequenced, so a single
    // expression would leave the sample-consumption order up to the compiler.
    const Spectrum emitterTerm = neeEmitter(scene, sampler, medium, p, wi, mis);
    const Spectrum phaseTerm = neePhase(scene, sampler, medium, p, wi, mis);
    return emitterTerm + phaseTerm;
}

/**
 * Estimates in-scattered radiance at a point `p` inside `medium`.
 *
 * Returns NEE at `p` (through the virtual handleNee) plus the radiance that
 * Li() returns along one direction sampled from the medium's phase function.
 * On that Li() path the bounce budget is reduced by one and emission is
 * excluded.
 *
 * @param medium       medium containing `p`; must be non-null.
 * @param wi           direction handed to the phase function for NEE and sampling.
 * @param max_bounces  bounce budget; Li() is called with max_bounces - 1.
 */
__attribute__((optnone)) Spectrum VolpathBase::Lins(const Scene &scene, RndSampler *sampler,
                                                    const Medium *medium, const Vector &p, const Vector &wi, int max_bounces)
{
    // `medium` is dereferenced below. handleNee is overridable (e.g.
    // RadianceTracer2) and is not guaranteed to check this itself.
    assert(medium);
    Spectrum ret(0.);

    // nee
    ret += handleNee(scene, sampler, medium, p, wi);

    // sample a direction from the phase function
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);
    Vector wo;
    // NOTE(review): the sampling weight returned here is discarded. That is
    // only unbiased if phase->sample returns 1 (neePhase notes "phase=1").
    // Folding it in would make RadianceTracer::sampleSource's per-bounce
    // records (which do not include this factor) disagree with the result.
    phase->sample(wi, sampler->next2D(), wo);

    // compute radiance
    return ret + Li(scene, sampler, medium, Ray(p, wo), max_bounces - 1, false);
}

/**
 * Unidirectional volumetric path tracer: estimates the radiance carried back
 * along `_ray`, which starts in `medium` (nullptr when not inside a medium).
 *
 * At a sampled medium interaction it:
 *  - adds NEE through the virtual handleNee(scene, ...);
 *  - calls the virtual handleMedium hook;
 *  - continues along a phase-sampled direction.
 *
 * At a surface it:
 *  - adds emission through handleEmission;
 *  - for non-null BSDFs, performs emitter + BSDF sampling with MIS and
 *    reports their sum through handleNee(nee_value);
 *  - continues along a BSDF-sampled direction.
 *
 * Null-BSDF surfaces do not count as bounces; they are capped at
 * max_null_interactions instead.
 *
 * @param max_bounces  maximum number of scattering events (medium or non-null surface).
 * @param incEmission  whether emission seen by the first segment is counted;
 *                     cleared after the first scattering event.
 * @return             accumulated radiance estimate.
 */
__attribute__((optnone)) Spectrum VolpathBase::Li(const Scene &scene, RndSampler *sampler,
                                                  const Medium *medium, const Ray &_ray, int max_bounces, bool incEmission)
{
    // fmt::print("Li11\n");
    Spectrum ret = Spectrum::Zero(),
             throughput = Spectrum::Ones();

    Ray ray(_ray);
    Intersection its;
    scene.rayIntersect(ray, false, its);

    // Reused across iterations; the surface branch reads the transmittance /
    // pdfFailure left by the sampleDistance call of the same iteration.
    MediumSamplingRecord mRec;
    int depth = 0, null_interations = 0;
    // FIXME
    while (depth <= max_bounces && null_interations < max_null_interactions)
    {
        bool inside_med = medium && medium->sampleDistance(ray, its.t, sampler, mRec);
        if (inside_med)
        {
            // sampled a medium interaction

            if (depth >= max_bounces)
                break;

            const PhaseFunction *phase = scene.phase_list[medium->phase_id];
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            // nee, emitter sampling + phase sampling
            // NOTE(review): unlike the surface branch below (which scales its NEE
            // by `throughput`), this value is added unscaled. That is only correct
            // if throughput == 1 here (e.g. unit albedo and unit phase/BSDF weights
            // so far) -- confirm. A fix must also change what handleNee records,
            // because RadianceTracer asserts ret equals the sum of recorded values.
            ret += handleNee(scene, sampler, medium, mRec.p, -ray.dir);

            // do something on the medium event, like branching out for boundary
            handleMedium(scene, sampler, medium, mRec.p, -ray.dir, max_bounces - depth, throughput);

            // sample direction
            Vector wo;
            Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);

            throughput *= phaseVal;
            ray = Ray(mRec.p, wo);
            scene.rayIntersect(ray, false, its);

            // update loop variables
            incEmission = false;
            depth++;
        }
        else
        {
            // sampled a surface interaction
            if (medium)
            {
                throughput *= mRec.transmittance / mRec.pdfFailure;
            }
            if (!its.isValid())
                break;

            if (its.isEmitter())
                ret += throughput * handleEmission(its, -ray.dir, incEmission);

            // Null surfaces may be crossed even when the bounce budget is spent.
            if (!its.getBSDF()->isNull() && depth >= max_bounces)
                break;

            // Sum of emitter- and BSDF-sampling NEE at this vertex, reported
            // through handleNee(nee_value) below.
            Spectrum nee_value(0.);
            // ====================== emitter sampling =========================
            DirectSamplingRecord dRec(its);
            if (!its.getBSDF()->isNull())
            {
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, its, sampler->next2D(), sampler, medium);

                if (!value.isZero(Epsilon))
                {
                    Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                    Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                    Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);

                    Spectrum nee_emitter = throughput * value * bsdfVal * mis_weight;
                    nee_value += nee_emitter;
                    ret += nee_emitter;
                }
            }

            // ====================== BSDF sampling =============================
            Vector wo;
            Float bsdfPdf, bsdfEta;
            Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
            // NOTE: need to comment out these code, otherwise the handleNee will go wrong
            // (presumably because handleNee(nee_value) must still run once per
            // bounce so per-depth recorders such as RadianceTracer stay aligned)
            // if (bsdfWeight.isZero(Epsilon))
            //     break;

            wo = its.toWorld(wo);
            ray = Ray(its.p, wo);

            throughput *= bsdfWeight;
            if (its.isMediumTransition())
            {
                medium = its.getTargetMedium(wo);
            }
            if (its.getBSDF()->isNull())
            {
                // Pass straight through: no NEE record, no bounce counted.
                scene.rayIntersect(ray, true, its);
                null_interations++;
            }
            else
            {
                // Trace the BSDF-sampled ray; if it reaches an emitter, add its
                // MIS-weighted contribution (the BSDF half of surface NEE).
                Spectrum value = scene.rayIntersectAndLookForEmitter(
                    ray, true, sampler, medium, its, dRec);
                if (!value.isZero(Epsilon))
                {
                    Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                    Spectrum nee_bsdf = throughput * value * mis_weight;
                    nee_value += nee_bsdf;
                    ret += nee_bsdf;
                }
                handleNee(nee_value);

                incEmission = false;
                depth++;
            }
        }
    }
    if (null_interations == max_null_interactions)
    {
        if (verbose)
            fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
    }
    return ret;
}

/**
 * Default medium-vertex NEE hook.
 *
 * Computes the MIS-weighted NEE estimate at `p`, forwards it to the
 * single-argument handleNee so subclasses can observe or record it, and
 * returns it to the caller.
 *
 * @param medium  medium containing `p`; must be non-null.
 */
__attribute__((optnone)) Spectrum VolpathBase::handleNee(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p, const Vector &wi)
{
    assert(medium != nullptr);
    const Spectrum contribution = nee(scene, sampler, medium, p, wi, /*mis=*/true);
    handleNee(contribution);
    return contribution;
}

/**
 * Emission seen at a surface hit.
 *
 * Returns its.Le(wo) only when `incEmission` is set; otherwise zero. After
 * the first scattering event Li() clears the flag, so emission found by
 * chance is not counted on top of NEE.
 *
 * @param its  surface intersection; must be an emitter.
 */
Spectrum VolpathBase::handleEmission(const Intersection &its, const Vector &wo, bool incEmission)
{
    assert(its.isEmitter());
    if (!incEmission)
        return Spectrum(0.);
    return its.Le(wo);
}

// ============================================================================
//                          Radiance Tracer
// ============================================================================

/**
 * Records one per-bounce radiance estimate.
 *
 * Writes `value` into slot `depth` of `radiances`, then advances `depth`.
 * The caller must have sized `radiances` for every bounce that will be recorded.
 */
void RadianceTracer::handleNee(const Spectrum &value)
{
    assert(depth < radiances.size());
    radiances[depth++] = value;
}

/**
 * Traces a path with VolpathBase::Li while recording per-bounce NEE values.
 *
 * Resets the recorder to `max_bounces` zero slots, traces, and returns the
 * sum of the recorded values. In debug builds it asserts that this sum
 * matches what VolpathBase::Li returned.
 */
Spectrum RadianceTracer::Li(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Ray &ray, int max_bounces, bool incEmission)
{
    // Reset the per-bounce recorder before tracing.
    depth = 0;
    radiances.clear();
    radiances.resize(max_bounces, Spectrum(0.));

    const Spectrum traced = VolpathBase::Li(scene, sampler, medium, ray, max_bounces, incEmission);

    Spectrum total(0.);
    for (const auto &perBounce : radiances)
        total += perBounce;

    assert(traced.isApprox(total));
    return total;
}

__attribute__((optnone)) std::tuple<Spectrum, std::vector<Spectrum>>
RadianceTracer::sampleSource(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p, const Vector &wi, int max_bounces)
{
    depth = 0;
    radiances.clear();
    radiances.resize(max_bounces, Spectrum(0.));
    Spectrum value = Lins(scene, sampler, medium, p, wi, max_bounces);
    for (int i = 1; i < radiances.size(); i++)
    {
        radiances[i] += radiances[i - 1];
    }
    assert(value.isApprox(radiances.back()));
    return {value, radiances};
}

// ============================================================================
//                          Radiance tracer without nee
// ============================================================================

/**
 * Medium-vertex NEE is disabled in this tracer.
 *
 * Always returns zero and records nothing; all parameters are ignored.
 */
Spectrum RadianceTracer2::handleNee(const Scene & /*scene*/, RndSampler * /*sampler*/,
                                    const Medium * /*medium*/, const Vector & /*p*/, const Vector & /*wi*/)
{
    return Spectrum(0.);
}

/**
 * Emission at a surface hit, counted on every hit: `incEmission` is ignored.
 *
 * NOTE(review): VolpathBase::Li still performs surface NEE (emitter and BSDF
 * sampling) independently of handleNee. On non-null surfaces this may count
 * an emitter twice -- confirm this tracer is only used where that cannot happen.
 *
 * @param its  surface intersection; must be an emitter.
 */
Spectrum RadianceTracer2::handleEmission(const Intersection &its, const Vector &wo, bool /*incEmission*/)
{
    assert(its.isEmitter());
    const Spectrum emitted = its.Le(wo);
    return emitted;
}

//=============================================================================
//              Radiance tracer 3 for Unidirectional Boundary
//=============================================================================
void RadianceTracer3::handleBoundary(const Scene &scene, RndSampler *sampler,
                                     const Medium *medium, const Ray &_ray,
                                     int max_bounces, const Spectrum &throughput)
{
    Ray ray{_ray};
    // find all null intersections
    std::vector<Intersection> its_list = scene.rayIntersectAll(ray, false);
    if (its_list.empty())
        return;
    // estimate the boundary integral of single scattering
    for (size_t i = 0; i < its_list.size(); i++)
    {
        const Intersection &its = its_list[i];
        // Vector dir_in = -ray.dir * math::signum(ray.dir.dot(its.geoFrame.n));
        Vector dir_in = -its.geoFrame.n;
        Vector p_shifted = its.p + dir_in * ShadowEpsilon;
        VolpathBase vp;
        // FIXME medium is hard coded to be the same
        const Medium *medium_in = its.ptr_med_int;
        Spectrum Lins = vp.Lins(scene, sampler, medium_in, p_shifted, dir_in, max_bounces - 1);
        // Spectrum Lins = vp.Lins(scene, p_shifted, -ray.dir, medium, _rRec);
        // Float transmittance = scene.evalTransmittance(ray.org, onSurface, its.p, true, medium1, rRec.sampler);
        Float transmittance = scene.evalTransmittance(ray.org, false, p_shifted, false, medium, sampler);
        Spectrum value = throughput * Lins * medium_in->sigS(p_shifted) * transmittance / std::abs(-ray.dir.dot(its.geoFrame.n));
        Float d_u = (dI * value).sum();
        algorithm1::d_velocity(sceneAD, its, d_u);
    }
}

void RadianceTracer3::handleMedium(const Scene &scene, RndSampler *sampler,
                                   const Medium *medium, const Vector &p, const Vector &wi,
                                   int max_bounces, const Spectrum &throughput)
{
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);
    assert(phase);
    if (max_bounces <= 0)
        return;
    Vector wo;
    Float phase_value = phase->sample(wi, sampler->next2D(), wo);
    Ray ray{p, wo};
    handleBoundary(scene, sampler, medium, ray, max_bounces - 1, throughput * phase_value);
}

// ============================================================================
//                           Particle tracing
// ============================================================================

void ParticleTracerBase::importance(const Scene &scene, RndSampler *sampler,
                                    const Medium *medium, const Ray &_ray, bool on_surface,
                                    int depth, int max_bounces,
                                    const Spectrum &_throughput)
{
    Spectrum throughput(_throughput);
    Ray ray(_ray);
    Intersection its;
    while (depth < max_bounces)
    {
        if (throughput.isZero())
            break;
        scene.rayIntersect(ray, on_surface, its);
        MediumSamplingRecord mRec;
        bool inside_med = medium != nullptr &&
                          medium->sampleDistance(ray, its.t, sampler, mRec);
        ray.org = mRec.p;
        if (inside_med)
        {
            // sampled a medium interaction
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            handleMedium(scene, sampler, medium, mRec.p, -ray.dir, depth, throughput);
            const auto *phase = scene.phase_list[medium->phase_id];
            Vector wo;
            Float phase_val = phase->sample(-ray.dir, sampler->next2D(), wo);
            if (phase_val < Epsilon)
                break;
            Float phasePdf = phase->pdf(-ray.dir, wo);
            throughput *= phase_val;
            ray.dir = wo;
            on_surface = false;
            depth++;
        }
        else
        {
            // hit the surface
            // if the ray going through a medium
            if (medium)
            {
                throughput *= mRec.transmittance / mRec.pdfFailure;
            }
            if (throughput.isZero())
                break;
            if (!its.isValid())
                break;
            if (!its.ptr_bsdf->isNull())
            {
                // handleSurfaceInteraction(scene, its, p_rec);
            }
            Float bsdf_pdf, bsdf_eta;
            Vector wo_local, wo;
            EBSDFMode mode = EBSDFMode::EImportanceWithCorrection;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, mode);
            if (bsdf_weight.isZero())
                break;
            wo = its.toWorld(wo_local);
            Vector wi = -ray.dir;
            Float wiDotGeoN = wi.dot(its.geoFrame.n),
                  woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 ||
                woDotGeoN * wo_local.z() <= 0)
            {
                break;
            }
            throughput *= bsdf_weight;
            if (its.isMediumTransition())
            {
                medium = its.getTargetMedium(woDotGeoN);
            }
            ray = Ray(its.p, wo);
            on_surface = true;
            if (!its.ptr_bsdf->isNull())
            {
                depth++;
            }
        }
    }
}

void ParticleTracerBase::traceParticle(const Scene &scene, RndSampler *sampler, int max_bounces)
{
    Intersection its;
    Spectrum throughput = scene.sampleEmitterPosition(sampler->next2D(), its);
    handleEmission(scene, sampler, its, throughput);
    Vector wo;
    Float pdf;
    throughput *= its.ptr_emitter->sampleDirection(sampler->next2D(), wo, &pdf);
    // not point source
    if (its.ptr_shape != nullptr)
        wo = its.geoFrame.toWorld(wo);

    Ray ray(its.p, wo);

    bool on_surface = true;
    const Medium *medium = its.getTargetMedium(ray.dir);

    importance(scene, sampler, medium, ray, true, 0, max_bounces, throughput);
}

/**
 * Connects a sampled emitter point directly to the sensor (depth 0).
 *
 * Samples a sensor connection from `its`, weights it by the emitter's
 * directional emission toward the sensor and by `throughput`, and forwards
 * any non-zero result to handleSensor.
 */
void ParticleTracerBase::handleEmission(const Scene &scene, RndSampler *sampler,
                                        const Intersection &its, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value(scene.sampleAttenuatedSensorDirect(its, sampler, cRec));
    // Nothing to splat if the sensor is unreachable or carries no importance.
    if (value.isZero(Epsilon) || !(cRec.baseVal > Epsilon))
        return;

    value *= its.ptr_emitter->evalDirection(its.geoFrame.n, cRec.dir) * throughput;
    if (value.isZero(Epsilon))
        return;
    handleSensor(scene, sampler, cRec, 0, value);
}

std::tuple<Spectrum, Intersection> ParticleTracerBase::sampleParticle(
    const Scene &scene, RndSampler *sampler)
{
    Intersection its;
    Spectrum value = scene.sampleEmitterPosition(sampler->next2D(), its); // intensity / pdf
    return {value, its};
}

void ParticleTracerBase::handleMedium(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p, const Vector &wi,
    int depth, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value{scene.sampleAttenuatedSensorDirect(p, medium, sampler, cRec)};
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        const PhaseFunction *ptr_phase = scene.phase_list[medium->phase_id];
        value *= ptr_phase->eval(wi, cRec.dir) * throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, depth, value);
    }
}

void ParticleTracerBase::handleMedium(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p,
    int depth, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value{scene.sampleAttenuatedSensorDirect(p, medium, sampler, cRec)};
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        value *= throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, depth, value);
    }
}

void ParticleTracerBase::handleSurface(
    const Scene &scene, RndSampler *sampler,
    const Intersection &its,
    int depth, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value(scene.sampleAttenuatedSensorDirect(its, sampler, cRec));
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        Vector wi = its.toWorld(its.wi);
        Vector wo = cRec.dir, wo_local = its.toLocal(wo);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() > Epsilon && woDotGeoN * wo_local.z() > Epsilon)
        {
            value *= its.ptr_bsdf->eval(its, wo_local, EBSDFMode::EImportanceWithCorrection) * throughput;
            if (!value.isZero(Epsilon))
                handleSensor(scene, sampler, cRec, depth, value);
        }
    }
}

// ============================================================================
//                           ImportanceTracer
// ============================================================================

/**
 * Sensor hook: turns a sensor connection into a boundary-velocity gradient.
 *
 * Multiplies the pixel adjoint from `d_image` by the sensor weight, the path
 * `throughput` and the recorded radiance for the remaining bounces, then
 * pushes the scalar sum into algorithm1::d_velocity at `its_b`. Under
 * FORWARD, the resulting parameter derivative is written into `grad_image`.
 */
__attribute__((optnone)) void ImportanceTracer::handleSensor(const Scene &scene, RndSampler *sampler,
                                                             const CameraDirectSamplingRecord &cRec,
                                                             int depth, const Spectrum &throughput)
{
    // Only the sensor weight is used; the adjoint is looked up at cRec.pixel_idx.
    // NOTE(review): ImportanceTracer2 uses the pixel returned by
    // sampleDirectPixel instead -- confirm which is intended.
    auto [sampled_pixel, sensor_value] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    Spectrum d_value = d_image.get(cRec.pixel_idx);
    // radiances contain phase value
    // NOTE(review): for depth == 0 (as passed by handleEmission) this reads
    // radiances[max_bounces]; confirm radiances holds max_bounces + 1 entries.
    d_value *= sensor_value * throughput * radiances[max_bounces - depth];
#ifdef FORWARD
    const int shape_idx = scene.getShapeRequiresGrad();
    sceneAD.getDer().shape_list[shape_idx]->param = 0;
#endif

    algorithm1::d_velocity(sceneAD, its_b, d_value.sum());

#ifdef FORWARD
    const Float param = sceneAD.getDer().shape_list[shape_idx]->param;
    grad_image.put(cRec.pixel_idx, Spectrum(param, 0, 0));
#endif
}

// ============================================================================
//                           Importance Tracer 2
// ============================================================================
/**
 * Sensor hook: turns a sensor connection into a boundary-velocity gradient,
 * using the single `radiance` member rather than a per-depth table.
 *
 * The receiving pixel is the one chosen by sampleDirectPixel; `depth` is unused.
 */
void ImportanceTracer2::handleSensor(const Scene &scene, RndSampler *sampler,
                                     const CameraDirectSamplingRecord &cRec,
                                     int depth, const Spectrum &throughput)
{
    auto [pixel_id, sensor_value] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    Array2i pixel_idx = unravel_index(pixel_id, scene.camera.getCropSize());
    Spectrum d_value = d_image.get(pixel_idx);
    // radiances contain phase value
    d_value *= sensor_value * throughput * radiance;
#ifdef FORWARD
    const int shape_idx = scene.getShapeRequiresGrad();
    sceneAD.getDer().shape_list[shape_idx]->param = 0;
#endif

    algorithm1::d_velocity(sceneAD, its_b, d_value.sum());

#ifdef FORWARD
    const Float param = sceneAD.getDer().shape_list[shape_idx]->param;
    grad_image.put(pixel_idx, Spectrum(param, 0, 0));
#endif
}

// ============================================================================
//                           ParticleTracer3
// ============================================================================
/**
 * Sensor hook for forward particle tracing: splats the contribution into
 * `image` at the pixel chosen by the camera, scaled by its sensor weight.
 * `depth` is unused.
 */
void ParticleTracer3::handleSensor(const Scene &scene, RndSampler *sampler,
                                   const CameraDirectSamplingRecord &cRec,
                                   int depth, const Spectrum &throughput)
{
    auto [pixel, weight] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    image[pixel] += throughput * weight;
}