#pragma once
#include <render/integrator.h>
#include <render/imageblock.h>
#include <render/intersection.h>
#include "../volpathBase.h"
#include "../ptracerBase.h"

/**
 * Per-query inputs for sampling a volume-boundary contribution
 * (consumed by BoundaryBidirectional::sampleBoundary).
 *
 * The record owns nothing: the sampler is a plain pointer the record never
 * frees, and the image blocks are references, so all of them must outlive
 * the record (and any VolumeParticleRecord copied from it).
 */
struct VolumeBoundaryQueryRecord
{
    RndSampler *sampler;          // random-number source; not owned
    int max_bounces;              // bounce limit for the sampled paths
    const ImageBlock &d_image;    // upstream gradient image (read-only)
    ImageBlock &grad_image;       // gradient image written by the query
};

struct VolumeParticleRecord : VolumeBoundaryQueryRecord
{
    VolumeParticleRecord(VolumeBoundaryQueryRecord &bRec)
        : VolumeBoundaryQueryRecord(bRec) {}
    Intersection its_b;
    Spectrum throughput;
    const Medium *medium;
    Vector vn; // normal velocity
};

// ============================================================================
//                 Radiance tracer with next-event estimation (NEE)
// ============================================================================

struct RadianceTracer : VolpathBase {
    RadianceTracer() = default;
    RadianceTracer(ESamplingMode mode) : VolpathBase(mode) {}
    RadianceTracer(const Properties &props) : VolpathBase(props) {}

    void handleNee(const Spectrum &value);

    Spectrum Li(const Scene &scene, RndSampler *sampler,
                const Medium *medium, const Ray &ray, int max_bounces, bool incEmission);

    Spectrum _Lins(const Scene &scene, RndSampler *sampler,
                  const Intersection &its,
                  const Vector &wi, int max_bounces);

    std::tuple<Spectrum, std::vector<Spectrum>> sampleSource(
        const Scene &scene, RndSampler *sampler,
        const Medium *medium, const Vector &p, const Vector &wi, int max_bounces);

    std::tuple<Spectrum, std::vector<Spectrum>> sampleSource(
        const Scene &scene, RndSampler *sampler,
        const Intersection &its, const Vector &wi, int max_bounces);

    int                   depth = 0;
    std::vector<Spectrum> radiances;
};

// ============================================================================
//              Radiance tracer without next-event estimation (NEE)
// ============================================================================


struct RadianceTracer2 : VolpathBase
{
    Spectrum handleNee(const Scene &scene, RndSampler *sampler,
                       const Medium *medium, const Vector &p, const Vector &wi) override;

    Spectrum handleEmission(const Intersection &its, const Vector &wo, bool incEmission) override;
};

/**
 * Particle (importance) tracer that connects its subpath to the sensor and
 * writes boundary-gradient contributions, using the values stored in `ctx`.
 */
struct ImportanceTracer : ParticleTracerBase, MISIntegrator
{
    /// Non-owning handles needed to connect the paths; every pointee must
    /// outlive the tracer.
    struct Context {
        SceneAD *sceneAD;                       // used in handleSensor.d_velocity
        const Intersection *its_b;              // boundary event, used in handleSensor.d_velocity
        int max_bounces;                        // used to connect the paths
        const ImageBlock *d_image;              // upstream gradient to fetch
        ImageBlock *grad_image;                 // gradient written in handleSensor
        const std::vector<Spectrum> *radiances; // from the radiance tracer; used to connect the paths
    };

    ImportanceTracer() = default;
    // Parameter named `context` so it no longer shadows the `ctx` member.
    ImportanceTracer(const Context &context, ESamplingMode mode)
        : ParticleTracerBase(), MISIntegrator(mode), ctx(context) {}
    ImportanceTracer(const Properties &props) : MISIntegrator(props) {}

    void sampleDetector()
    {
    }

    // Medium interaction at `p` (incoming `wi`, outgoing `wo`) at bounce
    // `depth` with path weight `throughput`.
    void handleMedium(const Scene &scene, RndSampler *sampler,
                      const Medium *medium, const Vector &p,
                      const Vector &wi, const Vector &wo,
                      int depth, const Spectrum &throughput) override;

    // Connect the boundary interaction directly to the camera.
    void handleBoundary(const Scene &scene, RndSampler *sampler,
                        const Medium *medium, const Vector &p,
                        const Spectrum &throughput);

    // Connect the source subpath with the detector subpath; `depth` tells
    // which bounce the connection happens on.
    void handleSensor(const Scene &scene, RndSampler *sampler,
                      const CameraDirectSamplingRecord &cRec,
                      int depth, const Spectrum &throughput) override;

    // Value-initialized: tracers built by the default or Properties
    // constructors previously held indeterminate pointers here; they now
    // hold nullptr (and max_bounces == 0) until a Context is supplied.
    Context ctx{};
    MISContext mis_ctx;
};

/**
 * Particle tracer that connects its subpath to the sensor using a single
 * precomputed `radiance` value.
 *
 * The reference members (sceneAD, its_b, d_image, grad_image) are
 * non-owning: the referenced objects must outlive the tracer.
 */
struct ImportanceTracer2 : ParticleTracerBase
{
    ImportanceTracer2(SceneAD &sceneAD, const Intersection &its_b, int max_bounces,
                      const ImageBlock &d_image, const Spectrum &radiance,
                      ImageBlock &grad_image)
        : sceneAD(sceneAD), its_b(its_b), max_bounces(max_bounces),
          d_image(d_image), radiance(radiance), grad_image(grad_image) {}

    void sampleDetector()
    {
    }

    // Connect the source subpath with the detector subpath; `depth` tells
    // which bounce the connection happens on.
    void handleSensor(const Scene &scene, RndSampler *sampler,
                      const CameraDirectSamplingRecord &cRec,
                      int depth, const Spectrum &throughput) override;

    SceneAD &sceneAD;          // used in handleSensor.d_velocity
    const Intersection &its_b; // boundary event, used in handleSensor.d_velocity
    int max_bounces;           // used to connect the paths
    const ImageBlock &d_image; // upstream gradient to fetch
    // Stored by value: a `const Spectrum &` member bound to the constructor
    // argument dangles as soon as a caller passes a temporary (e.g. the
    // result of an arithmetic expression).
    Spectrum radiance;         // used to connect the paths
    ImageBlock &grad_image;    // gradient written in handleSensor
};

// ============================================================================
//                          Boundary Bidirectional
// ============================================================================

/**
 * Integrator for the volume-boundary term.
 *
 * It samples a vertex on a medium boundary, traces a source subpath and a
 * detector subpath from it, connects them, and accumulates the resulting
 * boundary term into the gradient image.
 */
struct BoundaryBidirectional : Integrator, MISIntegrator {
    // ---- construction / setup ---------------------------------------------
    BoundaryBidirectional(const Scene &scene);
    BoundaryBidirectional(const Properties &props);
    void configure(const Scene &scene) override;

    // ---- rendering entry points -------------------------------------------
    /// Forward (primal) rendering.
    ArrayXd renderC(const Scene &scene, const RenderOptions &options) const override;
    /// Gradient rendering; `d_image` is presumably the upstream gradient.
    ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const override;

    // ---- boundary sampling ------------------------------------------------
    /// Distribution over the shapes that bound media.
    DiscreteDistribution buildShapeDistribution(const Scene &scene) const;
    /// Face distributions (presumably one per bounding shape).
    std::vector<DiscreteDistribution> buildFaceDistributions(const Scene &scene) const;
    /// Draws a point on a medium boundary from `rnd`; yields the
    /// intersection together with its sampling pdf.
    std::pair<Intersection, Float> sampleBoundaryPoint(const Scene &scene, const Array2 &rnd) const;
    /**
     * One boundary sample:
     *  - pick a vertex on the volume boundary,
     *  - trace the source and detector subpaths and join them,
     *  - evaluate the boundary normal velocity and the boundary term,
     *  - add the term to the gradient image referenced by `bRec`.
     */
    void sampleBoundary(SceneAD &sceneAD, VolumeBoundaryQueryRecord &bRec) const;

    // ---- state (declaration order kept: it fixes initialization order) ----
    bool                              m_adaptive      = false;
    int                               m_adaptive_mode = -1;
    DiscreteDistribution              m_shapeDistribution;
    std::vector<DiscreteDistribution> m_faceDistributions;
};