// Copyright @yucwang 2022

#pragma once
#ifndef LIBTENSORRAY_ALGORITHM1_RESTIR_MATERIALSPACERESTIR_H_
#define LIBTENSORRAY_ALGORITHM1_RESTIR_MATERIALSPACERESTIR_H_
#endif

#include "../PathSampler.h"
#include "ReSTIRDataTypes.h"

#include <memory>
#include <random>
#include <utility>
#include <vector>

namespace EDX {

namespace TensorRay {

class MaterialSpaceReSTIRPathSampler: public PathSampler
{
public:
    static shared_ptr<MaterialSpaceReSTIRPathSampler> GetInstance() {
        static shared_ptr<MaterialSpaceReSTIRPathSampler> s(new MaterialSpaceReSTIRPathSampler());
        return s;
    }
public:
    MaterialSpaceReSTIRPathSampler(): reservoir(make_shared<Reservoir>()), reservoirOnFile(make_shared<Reservoir>()) {
        std::random_device rd;
        randomGenerator = std::mt19937(rd());
    }

    virtual void SetParam(const RenderOptions& options) override;

    virtual void SamplePaths(const Scene& scene, PathSampleResult& res) const override;

    virtual void Step() const {
        if (!haveTemporalReuse) return;
        Reservoir r;
        if (reservoir->numOfSamples > 0) {
            IndexMask nonreuseReservoirMask = IndexMask(reuseMask > Scalar(0));
            if (nonreuseReservoirMask.sum > 0) {
                Reservoir nonreuseReservoir = reservoir->GetIndexedCopy(nonreuseReservoirMask.index, nonreuseReservoirMask.sum);
                Reservoir::Combine(*reservoirOnFile, nonreuseReservoir, r);
            } else {
                r = *reservoirOnFile;
            }
        } else {
            r = *reservoirOnFile;
        }
        reservoir.reset();
        reservoir = std::make_shared<Reservoir>(r);
        reuseMask = Ones(Shape({reservoir->numOfSamples}, VecType::Scalar1));
        reservoirOnFile.reset();
        reservoirOnFile = std::make_shared<Reservoir>();
    }

    int M;
    bool haveTemporalReuse;
    int k;
    int k1;
    int historyLength;
    float reservoirMergeNormalThreshold;
    float reservoirMergeDistThreshold;
    mutable std::shared_ptr<Reservoir> reservoir;
    mutable std::shared_ptr<Reservoir> reservoirOnFile;
    mutable Tensori reuseMask;
    mutable bool storeReservoir;

    mutable std::mt19937 randomGenerator;
    mutable std::vector<int> reservoirUsed;

private:
    std::shared_ptr<Reservoir> PrefilerReservoir(const Scene& scene, const Camera& camera,
                                                    const Tensorf& curPosition) const;
};

/// Declared here, defined elsewhere. `res` is taken by non-const reference and
/// is presumably the output tensor. NOTE(review): from the parameter names this
/// presumably evaluates using camera, current and light vertex sets — confirm
/// against the definition.
void EvalMaterialSpaceReSTIR(
    const Scene& scene,
    const SpatialVertices& vCamera,
    const SpatialVertices& vCur,
    const SpatialVertices& vLight,
    Tensorf& res);

} // namespace TensorRay
} // namespace EDX