#pragma once
#include <render/integrator.h>
#include <render/camera.h>
#include <render/imageblock.h>
#include <core/logger.h>
struct RndSampler;
struct Ray;
struct LightPath;

// Base type for the volumetric path tracing routines declared in this header.
// The estimators (Li, Lins, nee*) are non-virtual; the handle* members are
// virtual hooks. Members without an inline body are defined outside this header.
struct VolpathBase
{
    // The type has virtual members, so it may be destroyed through a base
    // pointer; a virtual destructor keeps that well-defined.
    virtual ~VolpathBase() = default;

    // Radiance:
    //   L(x, w) = \int T(x, x') \sigma_s(x') L(x', w) dt
    // Inputs: scene, sampler, the medium the ray starts in, the ray itself and
    // a bounce limit.
    // NOTE(review): incEmission presumably toggles whether emitted radiance is
    // counted -- confirm against the implementation.
    Spectrum Li(const Scene &scene, RndSampler *sampler,
                const Medium *medium, const Ray &ray, int max_bounces, bool incEmission);

    // In-scattering with wi:
    //   Lins(x, w) = \int f(x, -w', w) L(w') dw'
    // Inputs: medium, p, wi (optional), max_bounces.
    Spectrum Lins(const Scene &scene, RndSampler *sampler,
                  const Medium *medium, const Vector &p, const Vector &wi, int max_bounces);

    // Next-event estimation: sample an emitter from a medium event.
    // Inputs: medium, p, wi (optional), mis.
    Spectrum neeEmitter(const Scene &scene, RndSampler *sampler,
                        const Medium *medium, const Vector &p, const Vector &wi, bool mis);

    // Next-event estimation: sample the phase function from a medium event.
    Spectrum neePhase(const Scene &scene, RndSampler *sampler,
                      const Medium *medium, const Vector &p, const Vector &wi, bool mis);

    // Next-event estimation from a medium event.
    // NOTE(review): presumably combines neeEmitter and neePhase -- confirm.
    Spectrum nee(const Scene &scene, RndSampler *sampler,
                 const Medium *medium, const Vector &p, const Vector &wi, bool mis);

    // Callback for a medium event. The default implementation does nothing.
    virtual void handleMedium(const Scene &scene, RndSampler *sampler,
                              const Medium *medium, const Vector &p, const Vector &wi,
                              int max_bounces, const Spectrum &throughput) {}

    // Callback for a surface event. The default implementation does nothing.
    virtual void handleSurface() {}

    // Hook that produces the next-event estimation value at a medium event;
    // declared here, defined outside this header.
    virtual Spectrum handleNee(const Scene &scene, RndSampler *sampler,
                               const Medium *medium, const Vector &p, const Vector &wi);

    // Hook that receives a next-event estimation value. The default
    // implementation does nothing.
    virtual void handleNee(const Spectrum &value) {}

    // Hook for emission at an intersection; declared here, defined outside
    // this header.
    virtual Spectrum handleEmission(const Intersection &its, const Vector &wo, bool incEmission);
};

// Camera-agnostic particle tracer base. The tracing routines are non-virtual;
// the handle* members (except handleEmission) are virtual hooks. Members
// without an inline body are defined outside this header.
struct ParticleTracerBase
{
    // The type has virtual members, so it may be destroyed through a base
    // pointer; a virtual destructor keeps that well-defined.
    virtual ~ParticleTracerBase() = default;

    // Importance transport; the counterpart of Li on the radiance side
    // ("just like Li" in the original note).
    void importance(const Scene &scene, RndSampler *sampler,
                    const Medium *medium, const Ray &ray, bool on_surface,
                    int depth, int max_bounces,
                    const Spectrum &throughput);

    // Sample a particle on the emitters.
    // Returns a (Spectrum, Intersection) pair.
    // NOTE(review): presumably the particle weight and its starting point --
    // confirm against the implementation.
    std::tuple<Spectrum, Intersection> sampleParticle(const Scene &scene, RndSampler *sampler);

    // Sample a particle and trace it.
    void traceParticle(const Scene &scene, RndSampler *sampler, int max_bounces);

    // Non-virtual emission handler; defined outside this header.
    void handleEmission(const Scene &scene, RndSampler *sampler,
                        const Intersection &its, const Spectrum &throughput);

    // Connect to the camera from a medium event.
    virtual void handleMedium(const Scene &scene, RndSampler *sampler,
                              const Medium *medium, const Vector &p, const Vector &wi,
                              int depth, const Spectrum &throughput);

    // Overload without wi: no phase function handling, used for the medium
    // boundary term.
    virtual void handleMedium(const Scene &scene, RndSampler *sampler,
                              const Medium *medium, const Vector &p,
                              int depth, const Spectrum &throughput);

    // Hook for a surface event; defined outside this header.
    virtual void handleSurface(const Scene &scene, RndSampler *sampler,
                               const Intersection &its,
                               int depth, const Spectrum &throughput);

    // Handles the camera filter; supposed to be overridden by derived tracers
    // and can be used to compute the gradient. The default implementation does
    // nothing.
    virtual void handleSensor(const Scene &scene, RndSampler *sampler,
                              const CameraDirectSamplingRecord &cRec,
                              int depth, const Spectrum &throughput) {}
};

struct RadianceTracer : VolpathBase
{
    void handleNee(const Spectrum &value);

    Spectrum Li(const Scene &scene, RndSampler *sampler,
                const Medium *medium, const Ray &ray, int max_bounces, bool incEmission);

    std::tuple<Spectrum, std::vector<Spectrum>> sampleSource(
        const Scene &scene, RndSampler *sampler,
        const Medium *medium, const Vector &p, const Vector &wi, int max_bounces);
    int depth = 0;
    std::vector<Spectrum> radiances;
};

// Volumetric path tracer variant without next-event estimation.
// Replaces two VolpathBase hooks; both are defined outside this header.
struct RadianceTracer2 : VolpathBase
{
    // Replacement for the base next-event estimation hook at a medium event.
    Spectrum handleNee(const Scene &scene,
                       RndSampler *sampler,
                       const Medium *medium,
                       const Vector &p,
                       const Vector &wi) override;

    // Replacement for the base emission hook at an intersection.
    Spectrum handleEmission(const Intersection &its,
                            const Vector &wo,
                            bool incEmission) override;
};

// Volumetric path tracer used by the unidirectional boundary integrator.
struct RadianceTracer3 : VolpathBase
{
    // Binds the tracer to `sceneAD` and `dI`. Both are kept as references, not
    // copies, so the referenced objects must outlive this tracer.
    RadianceTracer3(SceneAD &sceneAD, const Spectrum &dI)
        : sceneAD(sceneAD),
          dI(dI)
    {
    }

    // Given a sampled ray, compute the boundary term.
    void handleBoundary(const Scene &scene,
                        RndSampler *sampler,
                        const Medium *medium,
                        const Ray &_ray,
                        int max_bounces,
                        const Spectrum &throughput);

    // Replaces the (empty) VolpathBase medium-event callback.
    void handleMedium(const Scene &scene,
                      RndSampler *sampler,
                      const Medium *medium,
                      const Vector &p,
                      const Vector &wi,
                      int max_bounces,
                      const Spectrum &throughput) override;

    SceneAD &sceneAD;   // non-owning reference supplied at construction
    const Spectrum &dI; // non-owning, read-only reference supplied at construction
};

// Particle tracer that connects to the sensor using a list of radiances.
// Every constructor argument except `max_bounces` is stored as a reference, so
// the referenced objects must outlive this tracer.
struct ImportanceTracer : ParticleTracerBase
{
    ImportanceTracer(SceneAD &sceneAD,
                     const Intersection &its_b,
                     int max_bounces,
                     const ImageBlock &d_image,
                     const std::vector<Spectrum> &radiances,
                     ImageBlock &grad_image)
        : sceneAD(sceneAD),
          its_b(its_b),
          max_bounces(max_bounces),
          d_image(d_image),
          radiances(radiances),
          grad_image(grad_image)
    {
    }

    // Intentionally empty.
    void sampleDetector() {}

    // Connects the source subpath with the detector subpath. Takes `depth`
    // because the connection needs to know the current bounce (original note).
    void handleSensor(const Scene &scene,
                      RndSampler *sampler,
                      const CameraDirectSamplingRecord &cRec,
                      int depth,
                      const Spectrum &throughput) override;

    // Scene with derivative data; used in handleSensor.d_velocity.
    SceneAD &sceneAD;
    // Boundary event; used in handleSensor.d_velocity.
    const Intersection &its_b;
    // Bounce limit; used to connect the paths.
    int max_bounces;
    // Upstream gradient is fetched from here.
    const ImageBlock &d_image;
    // Radiance values used to connect the paths.
    const std::vector<Spectrum> &radiances;
    // Gradient is stored here in handleSensor.
    ImageBlock &grad_image;
};

// Particle tracer that connects to the sensor using a single radiance value.
// Every constructor argument except `max_bounces` is stored as a reference, so
// the referenced objects must outlive this tracer.
struct ImportanceTracer2 : ParticleTracerBase
{
    ImportanceTracer2(SceneAD &sceneAD,
                      const Intersection &its_b,
                      int max_bounces,
                      const ImageBlock &d_image,
                      const Spectrum &radiance,
                      ImageBlock &grad_image)
        : sceneAD(sceneAD),
          its_b(its_b),
          max_bounces(max_bounces),
          d_image(d_image),
          radiance(radiance),
          grad_image(grad_image)
    {
    }

    // Intentionally empty.
    void sampleDetector() {}

    // Connects the source subpath with the detector subpath. Takes `depth`
    // because the connection needs to know the current bounce (original note).
    void handleSensor(const Scene &scene,
                      RndSampler *sampler,
                      const CameraDirectSamplingRecord &cRec,
                      int depth,
                      const Spectrum &throughput) override;

    // Scene with derivative data; used in handleSensor.d_velocity.
    SceneAD &sceneAD;
    // Boundary event; used in handleSensor.d_velocity.
    const Intersection &its_b;
    // Bounce limit; used to connect the paths.
    int max_bounces;
    // Upstream gradient is fetched from here.
    const ImageBlock &d_image;
    // Radiance value used to connect the paths.
    const Spectrum &radiance;
    // Gradient is stored here in handleSensor.
    ImageBlock &grad_image;
};

// Particle tracer that writes pixel colors into a caller-provided buffer.
struct ParticleTracer3 : ParticleTracerBase
{
    // Keeps a reference to `image` (no copy is made), so the vector must
    // outlive this tracer.
    ParticleTracer3(std::vector<Spectrum> &image)
        : image(image)
    {
    }

    // Replaces the (empty) ParticleTracerBase sensor hook.
    void handleSensor(const Scene &scene,
                      RndSampler *sampler,
                      const CameraDirectSamplingRecord &cRec,
                      int depth,
                      const Spectrum &throughput) override;

    // Destination for pixel colors.
    std::vector<Spectrum> &image;
};