#pragma once
#include <render/integrator.h>
#include "volpath.h"
#include <core/pmf.h>
#include <core/sampler.h>
#include <render/imageblock.h>
#include <render/intersection.h>
struct RndSampler;
struct Ray;

/// Unidirectional volumetric path tracer registered as "Volpath Interior".
/// NOTE(review): the "interior term only" role is inferred from the name and
/// from VolpathMerged pairing it with VolpathBoundary — confirm in the .cpp.
struct VolpathInterior : public UnidirectionalPathTracer
{
    /// Primal radiance estimate along `ray`.
    Spectrum Li(const Scene& scene, const Ray& ray, RadianceQueryRecord& rRec) const override;

    /// Differentiable counterpart of Li; `d_res` is presumably the upstream
    /// gradient of the returned radiance (TODO confirm).
    void LiAD(SceneAD& sceneAD, const Ray& ray, RadianceQueryRecord& rRec, const Spectrum& d_res) const override;

    std::string getName() const override
    {
        return "Volpath Interior";
    }
};

/// Unidirectional volumetric path tracer registered as "Volpath Boundary".
/// Its differentiable pass (LiAD) samples boundary paths and evaluates the
/// boundary term.
struct VolpathBoundary : UnidirectionalPathTracer
{
    /// Primal radiance estimate along `ray`.
    Spectrum Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const override;
    // Sample a boundary path and evaluate the boundary term.
    // `d_res` is presumably the upstream gradient of the radiance (TODO confirm).
    void LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const override;
    // Handle one boundary event along `_ray`, given the current path
    // `throughput`, whether the vertex is on a surface, and the medium
    // `medium1`.
    // NOTE(review): exact semantics of `onSurface`/`medium1` are defined in the
    // .cpp — confirm before relying on them.
    void handleBoundary(SceneAD &sceneAD, const Spectrum &throughput,
                        const Ray &_ray, bool onSurface, const Medium *medium1,
                        RadianceQueryRecord &rRec,
                        const Spectrum &d_res) const;
    // Alternative differentiable entry point with the same signature as LiAD.
    // NOTE(review): relationship to LiAD is not visible here — TODO document.
    void LiAD2(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const;
    std::string getName() const override { return "Volpath Boundary"; }

    // Embedded primal volumetric path tracer (presumably used by the boundary
    // estimator for sub-path radiance — TODO confirm).
    Volpath volpath;
};

/// Path tracer registered as "Volpath Merged" that holds both an interior and
/// a boundary integrator as members.
/// NOTE(review): how Li/LiAD combine the two members lives in the .cpp.
struct VolpathMerged : public UnidirectionalPathTracer
{
    /// Primal radiance estimate along `ray`.
    Spectrum Li(const Scene& scene, const Ray& ray, RadianceQueryRecord& rRec) const override;

    /// Differentiable pass; `d_res` is presumably the upstream radiance gradient.
    void LiAD(SceneAD& sceneAD, const Ray& ray, RadianceQueryRecord& rRec, const Spectrum& d_res) const override;

    std::string getName() const override
    {
        return "Volpath Merged";
    }

    // Sub-integrators (declaration order kept: it fixes construction order).
    VolpathInterior volpathInterior;
    VolpathBoundary volpathBoundary;
};

/// Image-level integrator ("Volpath2") that owns separate interior and
/// boundary path tracers.
struct Volpath2 : public Integrator
{
    /// Forward (primal) rendering.
    ArrayXd renderC(const Scene& scene, const RenderOptions& options) const override;

    /// Gradient rendering; `d_image` is presumably the upstream image gradient.
    ArrayXd renderD(SceneAD& sceneAD, const RenderOptions& options, const ArrayXd& d_image) const override;

    // NOTE(review): unlike the sibling integrators this is not marked
    // `override`; add it if Integrator declares a virtual getName().
    std::string getName() const
    {
        return "Volpath2";
    }

    VolpathInterior volpathInterior;
    VolpathBoundary volpathBoundary;
};

// struct Volpath2 : UnidirectionalPathTracer
// {
//     Spectrum Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const override;
//     void LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const override;
//     virtual ArrayXd renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const override;
//     Spectrum pixelColorAD1(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &_d_res) const;
//     Spectrum LiAD1(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const;
//     void pixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &d_res) const;

//     // estimate the boundary term of the new formulation
//     // taking the upstream gradient and backpropagating it to the d_scene,
//     // and return the boundary term gradient image if necessary.
//     ArrayXs boundaryTerm(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image);
//     Spectrum boundaryPixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &d_res);
//     Spectrum boundaryLiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res);

//     ArrayXs interiorTerm(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image);
//     Spectrum interiorPixelColorAD(SceneAD &sceneAD, PixelQueryRecord &pRec, const Spectrum &d_res);
//     Spectrum interiorLiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res);

//     std::string getName() const override { return "Volpath2"; }
// };

// ===========================================================================
//                    Bidirectional Boundary Integrator
// ===========================================================================
#include "volpathBase.h"
/// Shared per-query state handed to BoundaryBidirectional::sampleBoundary.
/// Contains reference members, so it must be aggregate-initialized and is not
/// copy-assignable.
struct VolumeBoundaryQueryRecord
{
    RndSampler* sampler;       ///< sampler for this query (ownership not expressed here)
    int max_bounces;           ///< bounce limit (per subpath? — TODO confirm)
    const ImageBlock& d_image; ///< upstream gradient, read only
    ImageBlock& grad_image;    ///< gradient image that boundary contributions are accumulated into
};

struct VolumeParticleRecord : VolumeBoundaryQueryRecord
{
    VolumeParticleRecord(VolumeBoundaryQueryRecord &bRec)
        : VolumeBoundaryQueryRecord(bRec) {}
    Intersection its_b;
    Spectrum throughput;
    const Medium *medium;
    Vector vn; // normal velocity
};

/// Gradient integrator that estimates the boundary term by sampling vertices
/// on medium-bounding surfaces and connecting source and detector subpaths.
struct BoundaryBidirectional : public Integrator
{
    /// Constructs the integrator for `scene`.
    /// NOTE(review): presumably fills `shapeDistribution` via
    /// buildShapeDistribution — confirm in the .cpp.
    BoundaryBidirectional(const Scene& scene);

    /// Builds the distribution over the shapes that bound media.
    DiscreteDistribution buildShapeDistribution(const Scene& scene) const;

    /// Samples a point on a medium boundary from the random pair `rnd`.
    /// @return the boundary intersection together with its sampling pdf.
    std::pair<Intersection, Float> sampleBoundaryPoint(const Scene& scene, const Array2& rnd) const;

    /// Estimates one boundary-term sample:
    ///   1. sample a vertex on the volume boundary;
    ///   2. sample the source subpath and the detector subpath;
    ///   3. connect the two subpaths;
    ///   4. compute the boundary normal velocity and the boundary term;
    ///   5. accumulate the term into the gradient image.
    void sampleBoundary(SceneAD& sceneAD, VolumeBoundaryQueryRecord& bRec) const;

    /// Forward (primal) rendering.
    ArrayXd renderC(const Scene& scene, const RenderOptions& options) const override;

    /// Gradient-image rendering; `d_image` is presumably the upstream gradient.
    ArrayXd renderD(SceneAD& sceneAD, const RenderOptions& options, const ArrayXd& d_image) const override;

    /// Distribution over medium-bounding shapes (presumably used by sampleBoundaryPoint).
    DiscreteDistribution shapeDistribution;
};