// Copyright @yucwang 2022

#pragma once
#ifndef LIBTENSORRAY_ALGORITHM1_RESTIR_RESTIRDATATYPES_H_
#define LIBTENSORRAY_ALGORITHM1_RESTIR_RESTIRDATATYPES_H_
#endif
#include "../../Integrator.h"
#include "../PathVertex.h"
#include "../../Scene.h"
#include "../../../Tensor/Tensor.h"

namespace EDX {
namespace TensorRay {

// Small tolerance associated with pHat (target-function) values. How it is
// applied (floor, division guard, ...) is decided at the use sites, which are
// not in this header.
constexpr float RESTIR_PHAT_EPSILON = 1e-6f;

/// Options for the RIS / ReSTIR integrators, layered on top of RenderOptions.
///
/// Every field carries an in-class default equal to the matching default
/// argument of the parameterized constructor, so a default-constructed
/// instance never holds indeterminate values.
struct RISRenderOptions : public RenderOptions {
    RISRenderOptions() : RenderOptions() {}

    // Candidate counts. NOTE(review): presumably RIS candidates for the
    // interior / edge estimators; confirm against the integrator.
    int M = 4;
    int Me0 = 4;
    int Me1 = 4;
    bool haveTemporalReuse = true;
    // Reuse parameters for the interior, secondary-boundary and pixel-boundary
    // estimators. NOTE(review): the exact meaning of k / k1 is defined by the
    // integrator, not here.
    int k = 1;
    int k1 = 1;
    int secBoundaryK = 2;
    int secBoundaryK1 = 1;
    int pixelBoundaryK = 1;
    int pixelBoundaryK1 = 1;
    bool boundaryGuiding = true;
    int historyLength = 4;
    // Thresholds consulted when deciding whether reservoirs may be merged
    // (presumably a normal-similarity bound and a distance bound).
    float reservoirMergeNormalThreshold = 0.8f;
    float reservoirMergeDistThreshold = 0.1f;
    bool storeReservoir = true;

    RISRenderOptions(int seed, int maxBounces, int spp,
                     float sppe, float sppse0, int sppse1, int sppe0 = 0, int _M = 4, int _Me0 = 4,
                     int _Me1 = 4, int _historyLength = 4, bool _haveTemporalReuse = true, int _k = 1, int _k1 = 1,
                     int _secBoundaryK = 2, int _secBoundaryK1 = 1, int _pixelBoundaryK = 1,
                     int _pixelBoundaryK1 = 1, bool _boundaryGuiding = true, float _reservoirMergeNormalThreshold = 0.8f,
                     float _reservoirMergeDistThreshold = 0.1f, bool _storeReservoir = true)
        : RenderOptions(seed, maxBounces, spp, sppe, sppse0, sppse1, sppe0),
          // Listed in declaration order (members are always initialized in
          // that order regardless of how this list is written).
          M(_M), Me0(_Me0), Me1(_Me1),
          haveTemporalReuse(_haveTemporalReuse), k(_k), k1(_k1),
          secBoundaryK(_secBoundaryK), secBoundaryK1(_secBoundaryK1),
          pixelBoundaryK(_pixelBoundaryK), pixelBoundaryK1(_pixelBoundaryK1),
          boundaryGuiding(_boundaryGuiding),
          historyLength(_historyLength),
          reservoirMergeNormalThreshold(_reservoirMergeNormalThreshold),
          reservoirMergeDistThreshold(_reservoirMergeDistThreshold),
          storeReservoir(_storeReservoir) {}
};

struct LightSample {
    int numOfSamples;
    std::shared_ptr<MaterialVertices> vertices; 
    Tensorf pHat;

    LightSample(): numOfSamples(0),
                   vertices(std::make_shared<MaterialVertices>()) {}
    LightSample(int _numOfSamples): numOfSamples(_numOfSamples), 
                                    vertices(std::make_shared<MaterialVertices>(_numOfSamples)) {
        pHat = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
    }

    static void Combine(const LightSample& s1, const LightSample& s2, LightSample& dst) {
        if (s1.numOfSamples == 0) {
            dst.pHat = s2.pHat;
        } else if (s2.numOfSamples == 0) {
            dst.pHat = s1.pHat;
        } else {
            dst.pHat = Detach(Concat(s1.pHat, s2.pHat, 0));
        }

        MaterialVertices::Combine(*s1.vertices, *s2.vertices, *dst.vertices);
        dst.numOfSamples = s1.numOfSamples + s2.numOfSamples;
    }

    void UpdateMasked(const LightSample& other, const Expr& mask) {
        pHat = Where(mask, other.pHat, pHat);
        if (numOfSamples == 0) {
            numOfSamples = other.numOfSamples;
        }
        vertices->UpdateMasked(*other.vertices, mask);
    }

    LightSample GetIndexedCopy(const Tensori& index, int size) const {
        LightSample ret;
        ret.vertices = std::make_shared<MaterialVertices>(vertices->GetIndexedCopy(index));
        ret.vertices->numOfSamples = size;
        ret.pHat = IndexedRead(pHat, index, 0);
        ret.numOfSamples = size;
        return ret;
    }

    void UpdateIndexed(const LightSample& other, const Tensori& index) {
        vertices->UpdateIndexed(*other.vertices, index);
        Tensorf pHat1 = IndexedRead(pHat, index, 0);
        pHat = Detach(pHat + IndexedWrite(other.pHat - pHat1, index, pHat.GetShape(), 0));
    }
};

/// A batch of ReSTIR reservoirs stored structure-of-arrays style: entry i of
/// each tensor, of `*x` and of `*shadingPoints` belongs to reservoir i.
///
/// Copies are shallow with respect to `x` and `shadingPoints` (shared_ptr
/// members): mutating a copy through UpdateIndexed also mutates the sample and
/// shading-point containers of the object it was copied from.
struct Reservoir {
    int numOfSamples;                               // number of reservoirs in the batch
    std::shared_ptr<LightSample> x;                 // sample held by each reservoir
    std::shared_ptr<SpatialVertices> shadingPoints; // shading point of each reservoir
    Tensorf wAverage;  // NOTE(review): presumably the running resampling weight — confirm
    Tensorf W;         // NOTE(review): presumably the contribution weight W — confirm
    Tensori m;         // NOTE(review): presumably the candidate count — confirm

    /// Empty batch with freshly allocated (empty) sample and vertex containers.
    Reservoir(): numOfSamples(0),
                 x(std::make_shared<LightSample>()),
                 shadingPoints(std::make_shared<SpatialVertices>()) {}

    /// Batch of `_numOfSamples` reservoirs with zero-filled wAverage, W and m.
    Reservoir(int _numOfSamples): numOfSamples(_numOfSamples),
                                  x(std::make_shared<LightSample>(_numOfSamples)),
                                  shadingPoints(std::make_shared<SpatialVertices>(_numOfSamples)) {
        wAverage = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
        W = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
        m = Zeros(Shape({ _numOfSamples }, VecType::Scalar1));
    }

    /// Writes r1 followed by r2 into dst. If one side has zero reservoirs, its
    /// per-reservoir tensors are skipped and the other side's are used as-is.
    static void Combine(const Reservoir& r1, const Reservoir& r2, Reservoir& dst) {
        if (r1.numOfSamples == 0) {
            dst.m = r2.m;
            dst.wAverage = r2.wAverage;
            dst.W = r2.W;
        } else if (r2.numOfSamples == 0) {
            dst.m = r1.m;
            dst.wAverage = r1.wAverage;
            dst.W = r1.W;
        } else {
            dst.m = Detach(Concat(r1.m, r2.m, 0));
            dst.wAverage = Detach(Concat(r1.wAverage, r2.wAverage, 0));
            dst.W = Detach(Concat(r1.W, r2.W, 0));
        }

        const int total = r1.numOfSamples + r2.numOfSamples;

        // Combine straight into heap-allocated results. Building a stack
        // temporary and handing it to make_shared would copy every member of
        // the combined LightSample / SpatialVertices a second time.
        auto combinedX = std::make_shared<LightSample>();
        LightSample::Combine(*r1.x, *r2.x, *combinedX);
        dst.x = combinedX;

        auto combinedPoints = std::make_shared<SpatialVertices>(total);
        SpatialVertices::Combine(*r1.shadingPoints, *r2.shadingPoints, *combinedPoints);
        dst.shadingPoints = combinedPoints;

        dst.numOfSamples = total;
    }

    /// Gathers the reservoirs selected by `index` into a new batch whose count
    /// is set to `size`. NOTE(review): `size` is trusted to equal the number of
    /// gathered entries; it is not checked here.
    Reservoir GetIndexedCopy(const Tensori& index, int size) const {
        Reservoir ret;
        ret.numOfSamples = size;
        ret.x = std::make_shared<LightSample>(x->GetIndexedCopy(index, size));
        ret.shadingPoints = std::make_shared<SpatialVertices>(shadingPoints->GetIndexedCopy(index));
        ret.shadingPoints->numOfSamples = size;
        ret.wAverage = IndexedRead(wAverage, index, 0);
        ret.W = IndexedRead(W, index, 0);
        ret.m = IndexedRead(m, index, 0);

        return ret;
    }

    /// Scatters `other` into the slots named by `index`; slots not named keep
    /// their current values. numOfSamples is left unchanged.
    void UpdateIndexed(const Reservoir& other, const Tensori& index) {
        x->UpdateIndexed(*other.x, index);
        shadingPoints->UpdateIndexed(*other.shadingPoints, index);

        // Each tensor uses the scatter-by-delta form
        //   t + scatter(other.t - t[index])
        // which keeps untouched slots and overwrites the indexed ones.
        // NOTE(review): relies on IndexedWrite producing zeros outside `index`
        // and on `index` having no duplicates — confirm.
        Tensorf wAverage1 = IndexedRead(wAverage, index, 0);
        wAverage = Detach(wAverage + IndexedWrite(other.wAverage - wAverage1, index, wAverage.GetShape(), 0));

        Tensorf W1 = IndexedRead(W, index, 0);
        W = Detach(W + IndexedWrite(other.W - W1, index, W.GetShape(), 0));

        Tensori m1 = IndexedRead(m, index, 0);
        m = Detach(m + IndexedWrite(other.m - m1, index, m.GetShape(), 0));
    }
};

// Reservoir update routines. Their definitions are not in this header; the
// notes below are read off the signatures only and should be confirmed
// against the implementation.

/// Streams candidate samples `xI` into the reservoir batch `r`.
/// @param r       reservoir batch to update
/// @param xI      candidate light samples
/// @param wI      candidate weights (NOTE(review): presumably per entry)
/// @param m1      candidate counts (NOTE(review): presumably accumulated into r.m)
/// @param sample  random numbers, presumably for the acceptance test
/// @return        NOTE(review): meaning not visible here — presumably which entries accepted
Tensori UpdateReservoir(Reservoir& r,
                        const LightSample& xI,
                        const Expr& wI,
                        const Expr& m1,
                        const Expr& sample);

/// Merges reservoir batch `r1` into `r`, using `sample` as random numbers.
/// NOTE(review): `mMax` presumably caps the candidate count — confirm.
void UpdateReservoir(Reservoir& r,
                     const Reservoir& r1,
                     const Expr& sample,
                     int mMax);

/// Alternative merge of `r1` into `r`; how it differs from UpdateReservoir is
/// not visible in this header.
void UpdateReservoir2(Reservoir& r,
                      const Reservoir& r1,
                      const Expr& sample);

/// Merges `r1` into `r` with access to the scene and rays, presumably to
/// re-evaluate the target function at r's shading points.
/// NOTE(review): meaning of the returned Tensori is not visible here.
Tensori UpdateReservoir(Reservoir& r,
                        const Reservoir& r1,
                        const Scene& scene,
                        const Ray& rays,
                        const Expr& sample);

/// Evaluates the target function pHat for each shading point / light-sample
/// pair. NOTE(review): `testShadowRay` presumably adds a visibility term via a
/// shadow ray — confirm in the definition.
Tensorf EvalPHat(const Scene& scene,
                 const Ray& ray,
                 const Intersection& curV,
                 const LightSample& lightSamples,
                 bool testShadowRay);

/// Overload of EvalPHat taking both the shading points and the light samples
/// as SpatialVertices.
Tensorf EvalPHat(const Scene& scene,
                 const Ray& ray,
                 const SpatialVertices& curV,
                 const SpatialVertices& lightSamples,
                 bool testShadowRay);

/// Screen-space ReSTIR evaluation over camera vertices `vCamera`, current
/// shading vertices `vCur` and light vertices `vLight`.
/// NOTE(review): `res` is a non-const reference and presumably receives the
/// result — confirm whether it is overwritten or accumulated into.
void EvalScreenSpaceReSTIR(
    const Scene& scene,
    const SpatialVertices& vCamera,
    const SpatialVertices& vCur,
    const SpatialVertices& vLight,
    Tensorf& res);

} // namespace TensorRay
} // namespace EDX