// Copyright @yucwang 2022

#pragma once

#include "../Boundary.h"
#include "../Integrator.h"
#include "ReSTIR/ReSTIRDataTypes.h"

#include <memory>
#include <random>
#include <vector>

namespace EDX {
namespace TensorRay {

/// A batch of edge-sampling reservoirs stored as parallel per-entry tensors.
///
/// Each of the `numOfSamples` entries pairs a light sample (`x`) with an edge
/// point (`edgePoints`) plus per-entry bookkeeping tensors. The bookkeeping
/// names follow ReSTIR conventions (presumably `W` = contribution weight,
/// `m` = number of candidates seen, `wAverage` = running weight average) —
/// NOTE(review): confirm against the out-of-line UpdateReservoir definitions.
struct EdgeReservoir {
    int numOfSamples;                                        // number of entries in this batch
    std::shared_ptr<LightSample> x;                          // light sample held by each entry
    std::shared_ptr<BoundarySegSampleSecondary> edgePoints;  // edge point held by each entry
    Tensorb isBoundary;
    Tensorf wAverage;
    Tensorf W;
    Tensori m;

    /// Creates an empty reservoir: zero entries and default-constructed sample containers.
    EdgeReservoir(): numOfSamples(0), 
                 x(std::make_shared<LightSample>()), 
                 edgePoints(std::make_shared<BoundarySegSampleSecondary>()) {}

    /// Creates a reservoir with `_numOfSamples` entries; every bookkeeping
    /// tensor (wAverage, W, m, isBoundary) is zero-filled.
    EdgeReservoir(int _numOfSamples): numOfSamples(_numOfSamples),
                                  x(std::make_shared<LightSample>(_numOfSamples)), 
                                  edgePoints(std::make_shared<BoundarySegSampleSecondary>(_numOfSamples)) {
        wAverage = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
        W = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
        m = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
        isBoundary = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
    }

    /// Concatenates `r1` and `r2` (r1's entries first) into `dst`.
    ///
    /// When one input is empty, the other input's bookkeeping tensors are
    /// assigned directly rather than concatenated; otherwise they are
    /// concatenated along axis 0 and detached. `dst` may alias `r1` or `r2`:
    /// every read of the inputs happens before the corresponding field of
    /// `dst` is overwritten.
    static void Combine(const EdgeReservoir& r1, const EdgeReservoir& r2, EdgeReservoir& dst) {
        // Read the merged size up front so it is unaffected if dst aliases an input.
        const int total = r1.numOfSamples + r2.numOfSamples;

        if (r1.numOfSamples == 0) {
            dst.m = r2.m;
            dst.wAverage = r2.wAverage;
            dst.W = r2.W;
            dst.isBoundary = r2.isBoundary;
        } else if (r2.numOfSamples == 0) {
            dst.m = r1.m;
            dst.wAverage = r1.wAverage;
            dst.W = r1.W;
            dst.isBoundary = r1.isBoundary;
        } else {
            dst.m = Detach(Concat(r1.m, r2.m, 0));
            dst.wAverage = Detach(Concat(r1.wAverage, r2.wAverage, 0));
            dst.W = Detach(Concat(r1.W, r2.W, 0));
            dst.isBoundary = Detach(Concat(r1.isBoundary, r2.isBoundary, 0));
        }

        // Merge the sample containers directly into freshly allocated shared
        // objects instead of merging into stack temporaries and copying them.
        // Both are fully built before dst.x / dst.edgePoints are replaced, so
        // r1.x / r1.edgePoints stay valid even when dst aliases r1.
        auto mergedLights = std::make_shared<LightSample>();
        LightSample::Combine(*r1.x, *r2.x, *mergedLights);

        auto mergedEdges = std::make_shared<BoundarySegSampleSecondary>(total);
        BoundarySegSampleSecondary::Combine(*r1.edgePoints, *r2.edgePoints, *mergedEdges);

        dst.x = mergedLights;
        dst.edgePoints = mergedEdges;
        dst.numOfSamples = total;
    }

    /// Returns a new reservoir holding the entries selected by `index`.
    /// @param index entries to gather (read along axis 0 of every tensor).
    /// @param size  number of gathered entries; stored as the copy's
    ///              numOfSamples and forwarded to LightSample::GetIndexedCopy.
    ///              NOTE(review): assumed to equal the length of `index`.
    EdgeReservoir GetIndexedCopy(const Tensori& index, int size) const {
        EdgeReservoir ret;
        ret.numOfSamples = size;
        ret.x = std::make_shared<LightSample>(x->GetIndexedCopy(index, size));
        ret.edgePoints = std::make_shared<BoundarySegSampleSecondary>(edgePoints->GetIndexedCopy(index));
        ret.wAverage = IndexedRead(wAverage, index, 0);
        ret.W = IndexedRead(W, index, 0);
        ret.m = IndexedRead(m, index, 0);
        ret.isBoundary = IndexedRead(isBoundary, index, 0);

        return ret;
    }

    /// Streams candidate light samples `xI` (weights `wI`, counts `m1`) into
    /// reservoir `r`, using `sample` as the random numbers. Defined out of line.
    static Tensori UpdateReservoir(EdgeReservoir& r, const LightSample& xI,const Expr& wI, const Expr& m1, 
                         const Expr& sample);
    
    /// Merges reservoir `r1` into `r`, using `sample` as the random numbers.
    /// Defined out of line.
    static Tensori UpdateReservoir(EdgeReservoir& r, const EdgeReservoir& r1, const Expr& sample);

    // NOTE(review): disabled helper kept for reference. Before re-enabling,
    // note that it refers to a `shadingPoints` member that EdgeReservoir does
    // not have, and it never updates `edgePoints` or `isBoundary`.
    // void UpdateIndexed(const EdgeReservoir& other, const Tensori& index) {
    //     x->UpdateIndexed(*other.x, index);
    //     // shadingPoints->UpdateIndexed(*other.shadingPoints, index);
    //     Tensorf wAverage1 = IndexedRead(wAverage, index, 0);
    //     wAverage = Detach(wAverage + IndexedWrite(other.wAverage - wAverage1, index, wAverage.GetShape(), 0));

    //     Tensorf W1 = IndexedRead(W, index, 0);
    //     W = Detach(W + IndexedWrite(other.W - W1, index, W.GetShape(), 0));

    //     Tensori m1 = IndexedRead(m, index, 0);
    //     m = Detach(m + IndexedWrite(other.m - m1, index, m.GetShape(), 0));
    // }
};

/// Batch of boundary-segment samples for direct illumination.
///
/// SampleBoundarySegmentReSTIR and the EdgeDirectSample overload of
/// SampleBoundarySegment take it by non-const reference (presumably as their
/// output), and EvalBoundarySegment(const Scene&, const EdgeDirectSample&, ...)
/// reads it. Field meanings below are inferred from names — NOTE(review): confirm.
class EdgeDirectSample {
public:
    std::shared_ptr<SpatialVertices> lightSamples;           // presumably sampled emitter points
    std::shared_ptr<BoundarySegSampleSecondary> edgePoints;  // sampled points on boundary edges
    Tensorf lightPdf;                                        // presumably pdf of each light sample
    Tensorf edgePointsPdf;                                   // presumably pdf of each edge point
    Tensorb isBoundary;                                      // per-sample flag; exact meaning TBD
};

// Cross-checks two edge lists of the same kind against each other.
// NOTE(review): defined out of line; presumably a debugging consistency check
// — confirm what is compared and how mismatches are reported.
void CheckEdgeList(const SecondaryEdgeInfo& lhs, const SecondaryEdgeInfo& rhs);
void CheckEdgeList(const EdgeIndexInfo& lhs, const EdgeIndexInfo& rhs);

/// Integrator for the direct-illumination boundary (edge) term, with
/// ReSTIR-style reservoir reuse (see EdgeReservoir and
/// SampleBoundarySegmentReSTIR).
///
/// Configuration is supplied through SetParam. Every scalar parameter has a
/// defined default (0/false) so the object is never in an indeterminate state
/// before SetParam is called.
class DirectBoundaryIntegrator2 : public Integrator
{
public:
    /// Creates the integrator with empty reservoirs and a Mersenne Twister
    /// seeded once from std::random_device.
    DirectBoundaryIntegrator2(): Integrator(),
        reservoirs(std::make_shared<EdgeReservoir>()),
        reservoirsOnFile(std::make_shared<EdgeReservoir>()),
        randomGenerator(std::random_device{}()) {}

    /// Copies rendering parameters from `options`.
    /// `options` must actually be a RISRenderOptions: the reference
    /// dynamic_cast below throws std::bad_cast otherwise.
    /// Also resets reservoirToUse to the indices 0 .. k-1.
    void SetParam(const RenderOptions& options) 
    {
        const RISRenderOptions& _options = dynamic_cast<const RISRenderOptions&>(options);
        M = _options.Me0;
        haveTemporalReuse = _options.haveTemporalReuse;
        k = _options.secBoundaryK;
        k1 = _options.secBoundaryK1;
        historyLength = _options.historyLength;
        mSpp = options.mSppDirect;
        mSppBatch = options.mSppDirectBatch;
        mMaxBounces = options.mMaxBounces;
        mVerbose = !options.mQuiet;
        boundaryGuiding = _options.boundaryGuiding;
        reservoirMergeNormalThreshold = _options.reservoirMergeNormalThreshold;
        reservoirMergeDistThreshold = _options.reservoirMergeDistThreshold;

        g_direct = options.g_direct;
        g_options.depth = options.g_direct_depth;
        g_options.max_size = options.g_direct_max_size;
        g_options.spp = options.g_direct_spp;
        g_options.thold = options.g_direct_thold;
        g_options.eps = options.g_eps;

        reservoirToUse.clear();
        for (int i = 0; i < k; ++i) reservoirToUse.push_back(i);
    }

    /// Renders the direct boundary term into `image`. Defined out of line.
    void Integrate(const Scene& scene, Tensorf& image) const;

    /// Draws `numSamples` boundary-segment samples on `secEdges` using the
    /// random numbers `rnd_b`, writing them into `samples`. Defined out of line.
    int SampleBoundarySegmentReSTIR(const Scene& scene, 
            const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, 
            EdgeDirectSample& samples) const;

    /// Evaluates the boundary contribution of previously drawn samples into
    /// `boundaryTerm`. Defined out of line.
    int EvalBoundarySegment(const Scene& scene, 
            const EdgeDirectSample &secEdges, Tensorf& boundaryTerm, float spp) const;

    /// Non-ReSTIR boundary sampling with an explicit per-edge pdf `pdf_b`
    /// (BoundarySegSampleDirect output). Defined out of line.
    int SampleBoundarySegment(const Scene& scene, 
            const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, 
            const Tensorf& pdf_b, BoundarySegSampleDirect& samples) const;

    /// Same as above but fills an EdgeDirectSample. Defined out of line.
    int SampleBoundarySegment(const Scene& scene, 
            const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, 
            const Tensorf& pdf_b, EdgeDirectSample& samples) const;

    /// Evaluates the target function p-hat for edge/light sample pairs;
    /// `validIndex` is an output. Defined out of line.
    Tensorf EvalPHat(const Scene& scene, const BoundarySegSampleSecondary& edgePoints,
                        const SpatialVertices& lightSamples, const Tensorb& isBoundary, Tensori& validIndex) const;

    /// Camera-explicit overload for BoundarySegSampleDirect samples; note
    /// that `spp` is an int here versus a float in the member overload.
    static int EvalBoundarySegment(const Camera& camera, const Scene& scene, 
                BoundarySegSampleDirect& bss, Tensorf& boundaryTerm, int spp); 

    /// Samples emitters using `rnd_light`, writing vertices into `vLight`.
    /// Defined out of line.
    void SampleEmitterDirect(const Scene& scene, const Tensorf& rnd_light, MaterialVertices& vLight) const;

    virtual void Step() override;

private:
    // --- configuration (written by SetParam) ---
    bool  g_direct = false;                    // presumably enables guided direct boundary sampling
    GuidingOption g_options;                   // guiding parameters (depth, max_size, spp, thold, eps)
    int M = 0;                                 // from RISRenderOptions::Me0
    mutable bool haveTemporalReuse = false;    // mutable: may be toggled from const methods
    bool boundaryGuiding = false;
    int k = 0;                                 // size of reservoirToUse (see SetParam)
    int k1 = 0;
    float reservoirMergeNormalThreshold = 0.0f;
    float reservoirMergeDistThreshold = 0.0f;
    int historyLength = 0;

    // --- per-render state; mutable because Integrate() is const ---
    mutable std::shared_ptr<EdgeReservoir> reservoirs;
    mutable std::shared_ptr<EdgeReservoir> reservoirsOnFile;
    mutable std::shared_ptr<EdgeIndexInfo> storedInfo;

    mutable std::mt19937 randomGenerator;
    mutable std::vector<int> reservoirToUse;
    mutable Tensori reuseMask;
    mutable Tensori T;

private:
    // NOTE: the name ("Prefiler") is kept as-is to match the out-of-line definition.
    std::shared_ptr<EdgeReservoir> PrefilerReservoir(const Scene& scene, 
        const Camera& camera,
        const std::shared_ptr<EdgeReservoir>& edgeReservoirs,
        bool updateReuseMask = false) const;
};

} // namespace TensorRay
} // namespace EDX