#include "ptracerBase.h"
#include <core/sampler.h>
#include <render/scene.h>
#include <render/medium.h>

void ParticleTracerBase::importance(const Scene &scene, RndSampler *sampler,
                                    const Medium *medium, const Ray &_ray, bool on_surface,
                                    int depth, int max_bounces,
                                    const Spectrum &_throughput)
{
    Spectrum throughput(_throughput);
    Ray ray(_ray);
    Intersection its;
    int k = 0;
    while (depth < max_bounces)
    {
        if (k > 100)
        {
            PSDR_WARN("ParticleTracerBase::importance: k > 100");
            break;
        }
        k++;

        if (throughput.isZero())
            break;
        scene.rayIntersect(ray, on_surface, its);
        MediumSamplingRecord mRec;
        bool inside_med = medium != nullptr &&
                          medium->sampleDistance(ray, its.t, sampler, mRec);
        ray.org = mRec.p;
        if (inside_med)
        {
            // sampled a medium interaction
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            const auto *phase = scene.phase_list[medium->phase_id];
            Vector wo;
            Float phase_val = phase->sample(-ray.dir, sampler->next2D(), wo);
            handleMedium(scene, sampler, medium, mRec.p, -ray.dir, wo, depth + 1, throughput);
            if (phase_val < Epsilon)
                break;
            Float phasePdf = phase->pdf(-ray.dir, wo);
            throughput *= phase_val;
            ray.dir = wo;
            on_surface = false;
            depth++;
        }
        else
        {
            // hit the surface
            // if the ray going through a medium
            if (medium)
            {
                throughput *= mRec.transmittance / mRec.pdfFailure;
            }
            if (throughput.isZero())
                break;
            if (!its.isValid())
                break;
            if (!its.ptr_bsdf->isNull())
            {
                if(dynamic_cast<const EnvironmentMap *>(its.ptr_emitter))
                    break;
                // handleSurfaceInteraction(scene, its, p_rec);
            }
            Float bsdf_pdf, bsdf_eta;
            Vector wo_local, wo;
            EBSDFMode mode = EBSDFMode::EImportanceWithCorrection;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, mode);
            if (bsdf_weight.isZero())
                break;
            wo = its.toWorld(wo_local);
            Vector wi = -ray.dir;
            Float wiDotGeoN = wi.dot(its.geoFrame.n),
                  woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 ||
                woDotGeoN * wo_local.z() <= 0)
            {
                break;
            }
            throughput *= bsdf_weight;
            if (its.isMediumTransition())
            {
                medium = its.getTargetMedium(woDotGeoN);
            }
            ray = Ray(its.p, wo);
            on_surface = true;
            if (!its.ptr_bsdf->isNull())
            {
                depth++;
            }
        }
    }
}

// Emits a single particle and traces it through the scene.
//
// Steps:
//   1. Sample a position on an emitter (scene.sampleEmitterPosition) and pass
//      it to handleEmission() together with the position weight.
//   2. Sample an outgoing direction from that emitter and fold its weight into
//      the throughput. When the emitter has an associated shape the sampled
//      direction is transformed from the geometric frame to world space;
//      otherwise (e.g. a point source) it is used as returned.
//   3. Continue the path with importance(), starting at depth 0.
//
// Parameters:
//   scene        - scene to sample the emitter from and to trace through.
//   sampler      - source of random numbers.
//   max_bounces  - depth limit forwarded to importance().
void ParticleTracerBase::traceParticle(const Scene &scene, RndSampler *sampler, int max_bounces)
{
    Intersection its;
    Spectrum throughput = scene.sampleEmitterPosition(sampler->next2D(), its);
    handleEmission(scene, sampler, its, throughput);
    Vector wo;
    Float pdf; // written by sampleDirection(); not needed afterwards
    throughput *= its.ptr_emitter->sampleDirection(sampler->next2D(), wo, &pdf);
    // not point source
    if (its.ptr_shape != nullptr)
        wo = its.geoFrame.toWorld(wo);

    Ray ray(its.p, wo);

    // The first ray is flagged as starting on a surface.
    const bool on_surface = true;
    const Medium *medium = its.getTargetMedium(ray.dir);

    importance(scene, sampler, medium, ray, on_surface, 0, max_bounces, throughput);
}

void ParticleTracerBase::handleEmission(const Scene &scene, RndSampler *sampler,
                                        const Intersection &its, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value(scene.sampleAttenuatedSensorDirect(its, sampler, cRec));
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        value *= its.ptr_emitter->evalDirection(its.geoFrame.n, cRec.dir) * throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, 0, value);
    }
}

std::tuple<Spectrum, Intersection> ParticleTracerBase::sampleParticle(
    const Scene &scene, RndSampler *sampler)
{
    Intersection its;
    Spectrum value = scene.sampleEmitterPosition(sampler->next2D(), its); // intensity / pdf
    return {value, its};
}

// handle phase function
void ParticleTracerBase::handleMedium(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p,
    const Vector &wi, const Vector &wo,
    int depth, const Spectrum &throughput) {
    CameraDirectSamplingRecord cRec;
    Spectrum value{scene.sampleAttenuatedSensorDirect(p, medium, sampler, cRec)};
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        const PhaseFunction *ptr_phase = scene.phase_list[medium->phase_id];
        value *= ptr_phase->eval(wi, cRec.dir) * throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, depth, value);
    }
}

// does not handle phase function
void ParticleTracerBase::handleMedium(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p,
    int depth, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value{scene.sampleAttenuatedSensorDirect(p, medium, sampler, cRec)};
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        value *= throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, depth, value);
    }
}

void ParticleTracerBase::handleSurface(
    const Scene &scene, RndSampler *sampler,
    const Intersection &its,
    int depth, const Spectrum &throughput)
{
    CameraDirectSamplingRecord cRec;
    Spectrum value(scene.sampleAttenuatedSensorDirect(its, sampler, cRec));
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon)
    {
        Vector wi = its.toWorld(its.wi);
        Vector wo = cRec.dir, wo_local = its.toLocal(wo);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() > Epsilon && woDotGeoN * wo_local.z() > Epsilon)
        {
            value *= its.ptr_bsdf->eval(its, wo_local, EBSDFMode::EImportanceWithCorrection) * throughput;
            if (!value.isZero(Epsilon))
                handleSensor(scene, sampler, cRec, depth, value);
        }
    }
}

// ============================================================================
//                           ParticleTracer3
// ============================================================================
// Splats a sensor contribution into the image buffer.
//
// A pixel and its weight are drawn with scene.camera.sampleDirectPixel() using
// one random number from `sampler`; `throughput` scaled by that weight is added
// to `image` at the returned pixel index.
//
// Parameters:
//   scene       - scene whose camera selects the pixel.
//   sampler     - source of the random number used for the pixel selection.
//   cRec        - sensor sampling record of the connection being splatted.
//   depth       - not used by this implementation.
//   throughput  - contribution to accumulate.
//
// NOTE(review): the accumulation into `image` is unsynchronized here -- confirm
// that callers never share `image` between threads.
void ParticleTracer3::handleSensor(const Scene &scene, RndSampler *sampler,
                                   const CameraDirectSamplingRecord &cRec,
                                   int depth, const Spectrum &throughput)
{
    const auto [pixel_index, pixel_weight] =
        scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    image[pixel_index] += throughput * pixel_weight;
}