#pragma once
#include <render/integrator.h>

#include "../volpath.h"
#include "../volpathBase.h"
// Forward declarations: within this header RndSampler is only used through
// pointers and Ray only through const references, so the complete types are
// not required here.
struct RndSampler;
struct Ray;

// Volumetric path-tracing hooks ("volpath") for the unidirectional boundary
// integrator. Extends VolpathBase and overrides the inherited handleSensor /
// handleMedium / handleSurface / isNext virtuals; every member function is
// only declared here and defined elsewhere.
struct BoundaryUnidirectionalBase : VolpathBase {
    // Extra state stored next to the base-class configuration.
    struct Context {
        // Raw pointer to the differentiable scene; stays null until a caller
        // provides one.
        SceneAD *sceneAD = nullptr;
        // NOTE(review): presumably the derivative weight propagated into the
        // boundary term -- confirm against the implementation file.
        Spectrum dI;
    };

    // Leaves `ctx` and `mis_ctx` default-initialised.
    BoundaryUnidirectionalBase() = default;

    // Stores a copy of `ctx` and forwards `mode` to VolpathBase.
    BoundaryUnidirectionalBase(const Context &ctx, ESamplingMode mode)
        : VolpathBase(mode),
          ctx(ctx) {}

    // Forwards `props` to VolpathBase; `ctx` keeps its default values.
    BoundaryUnidirectionalBase(const Properties &props)
        : VolpathBase(props) {}

    // Given a sampled ray, computes the boundary term.
    // `pdf` is measured with respect to solid angle.
    // Returns the number of intersections that were found.
    int handleBoundary(
        const Scene &scene,
        RndSampler *sampler,
        const Medium *medium,
        const Ray &_ray,
        int max_bounces,
        const Spectrum &throughput,
        Float pdf /*solid angle*/);

    // Override of the inherited sensor hook.
    void handleSensor(
        const Scene &scene,
        RndSampler *sampler,
        const Ray &ray,
        const Array2i &pixel_idx,
        const Medium *medium,
        int max_bounces) override;

    // Override of the inherited medium hook.
    // NOTE(review): `p` looks like the position inside `medium` and `wi`/`wo`
    // the two directions at that point -- confirm the direction convention.
    void handleMedium(
        const Scene &scene,
        RndSampler *sampler,
        const Medium *medium,
        const Vector &p,
        const Vector &wi,
        const Vector &wo,
        int max_bounces,
        const Spectrum &throughput) override;

    // Override of the inherited surface hook; `its` is the surface
    // intersection being processed.
    void handleSurface(
        const Scene &scene,
        RndSampler *sampler,
        const Intersection &its,
        int max_bounces,
        const Spectrum &throughput) override;

    // Override of the inherited predicate taking the current bounce count.
    // NOTE(review): presumably decides whether the path continues -- confirm.
    bool isNext(int bounces) const override;

    // NOTE(review): the name suggests an incident / in-scattered radiance
    // estimate at `its` for direction `wi`, limited to `max_bounces` --
    // confirm against the definition.
    Spectrum Lins(
        const Scene &scene,
        RndSampler *sampler,
        const Intersection &its,
        const Vector &wi,
        int max_bounces);

    Context ctx;        // per-instance context (see Context)
    MISContext mis_ctx; // default-constructed by every constructor above
};

// Integrator that derives from both UnidirectionalPathTracer and
// MISIntegrator (public inheritance, the default for a struct). All member
// functions except getName() are only declared here.
struct BoundaryUnidirectional : UnidirectionalPathTracer, MISIntegrator {
    // Passes the same `props` to both base classes. When no argument is
    // given, the defaults are enable_antithetic = false and
    // sampling_mode = "area".
    BoundaryUnidirectional(const Properties &props = Properties({ { "enable_antithetic", false }, {"sampling_mode", "area"} }))
        : UnidirectionalPathTracer(props),
          MISIntegrator(props) {}

    // Override of the inherited radiance query for `ray`.
    Spectrum Li(const Scene &scene,
                const Ray &ray,
                RadianceQueryRecord &rRec) const override;

    // Samples a boundary path and evaluates the boundary term.
    // NOTE(review): `d_res` is presumably the derivative of the loss with
    // respect to this query's result -- confirm.
    void LiAD(SceneAD &sceneAD,
              const Ray &ray,
              RadianceQueryRecord &rRec,
              const Spectrum &d_res) const override;

    // Boundary-term helper, declared without `override`.
    // NOTE(review): `onSurface` / `medium1` appear to describe where `_ray`
    // starts -- confirm against the definition.
    void handleBoundary(SceneAD &sceneAD,
                        const Spectrum &throughput,
                        const Ray &_ray,
                        bool onSurface,
                        const Medium *medium1,
                        RadianceQueryRecord &rRec,
                        const Spectrum &d_res) const;

    // Takes the same parameters as LiAD but is declared without `override`.
    // NOTE(review): looks like an alternative LiAD implementation -- confirm.
    void LiAD2(SceneAD &sceneAD,
               const Ray &ray,
               RadianceQueryRecord &rRec,
               const Spectrum &d_res) const;

    // Name under which this integrator identifies itself.
    std::string getName() const override {
        return "BoundaryUnidirectional";
    }
};