// Copyright @yucwang 2022

#pragma once
#include "Integrator.h"
#include "Ray.h"
#include "Records.h"

// namespace EDX
// {
//     namespace TensorRay
//     {
//         struct RISRenderOptions : public RenderOptions {
//             RISRenderOptions(): RenderOptions() {};
//             int M;
//             bool haveTemporalReuse = true;
// 
//             RISRenderOptions(int seed, int maxBounces, int spp, 
//                           int sppe, int sppse0, int sppse1, int sppe0 = 0, int M = 4, bool haveTemporalReuse = true)
//                         : RenderOptions(seed, maxBounces, spp, sppe, sppse0, sppse1, sppe0), 
//                           M(M), haveTemporalReuse(haveTemporalReuse) {}
//         };
// 
//         struct LightSample {
//             int numOfSamples;
//             PositionSample positionSample;
//             Expr pHat;
// 
//             LightSample(): numOfSamples(0) {}
//             LightSample(int size): numOfSamples(size) {
//                 pHat = Zeros(Shape({ size }, VecType::Scalar1));
//             }
// 
//             LightSample(const PositionSample& posSample, const Expr& _pHat):
//                 positionSample(posSample), pHat(_pHat) {}
// 
//             LightSample GetMaskedCopy(const IndexMask& mask) const {
//                 LightSample ret;
//                 ret.positionSample = positionSample.GetMaskedCopy(mask);
//                 ret.pHat = Mask(pHat, mask, 0);
// 
//                 return ret;
//             }
// 
//             LightSample GetDetach() const {
//                 LightSample ret;
//                 ret.positionSample = positionSample.GetDetach();
//                 ret.pHat = Detach(pHat);
// 
//                 return ret;
//             }
// 
//             void WriteToIndex(const LightSample& other, const Expr& index) {
//                 //Expr cond = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
//                 //cond = cond + IndexedWrite(Scalar(1), index, cond->GetShape(), 0);
// 
//                 Expr pHat1 = IndexedRead(pHat, index, 0);
//                 Expr pHatVal = IndexedWrite(Detach(other.pHat) - pHat1, index, pHat->GetShape(), 0);
//                 pHat = pHat + pHatVal;
// 
//                 positionSample.WriteToIndex(other.positionSample, index);
//             }
// 
//             LightSample GetIndexedCopy(const Expr& index, int size) const {
//                 LightSample ret;
//                 ret.numOfSamples = size;
//                 ret.pHat = IndexedRead(pHat, index, 0);
//                 ret.positionSample = positionSample.GetIndexedCopy(index, size);
// 
//                 return ret;
//             }
// 
//             void UpdateMasked(const Expr& mask, const LightSample& newSample) {
//                 positionSample.UpdateMasked(mask, newSample.positionSample);
//                 pHat = Where(mask, newSample.pHat, pHat);
// 
//                 if (numOfSamples == 0) {
//                     numOfSamples = newSample.numOfSamples;
//                 }
//             }
// 
//             void Eval() {
//                 positionSample.Eval();
//                 pHat = Tensorf(Detach(pHat));
//             }
//         };
// 
//         class Reservoir {
//             public:
//                 Reservoir(): numOfSamples(0) {}
//                 Reservoir(int _size): numOfSamples(_size), x(_size) {
//                     wAverage = Zeros(Shape({ _size }, VecType::Scalar1));
//                     W = Zeros(Shape({ _size }, VecType::Scalar1));
//                     m = Zeros(Shape({ _size }, VecType::Scalar1));
//                     n = Zeros(Shape({ _size }, VecType::Vec3));
//                 }
// 
//                 void Update(const LightSample& xI, const Expr& wI, const Expr& m1, 
//                                     const Expr& sample, int mMax) {
//                     m = m + m1;
// 
//                     Expr haveSampleMask = (m > Scalar(0));
//                     Expr factor = Where(haveSampleMask, Scalar(1.0f) / m, Scalar(0.0f));
// 
//                     wAverage = (m - m1) * factor * wAverage + m1 * factor * wI;
// 
//                     Expr wAverageVal = Where(wAverage > Scalar(0.0f), Scalar(1.0f) / wAverage, Scalar(0.0f));
//                     Expr updateMask = (haveSampleMask * (sample < (wI * m1 * factor * wAverageVal)));
// 
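//                     // Workaround for the case where all M candidate samples have p_hat = 0:
//                     // force-accept xI wherever m1 > 0 and the current sample's pHat is ~0 (< 1e-9).
//                     // TODO: can this special case be removed?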
//                     Expr forceUpdate =  ((m1 > Scalar(0)) * (x.pHat < Scalar(1e-9f)));
//                     x.UpdateMasked(forceUpdate, xI);
// 
//                     x.UpdateMasked(updateMask, xI);
//                 }
// 
//                 void WriteToIndex(const Reservoir& other, const Expr& index) {
//                     // Expr cond = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
//                     // cond = cond + IndexedWrite(Scalar(1), index, cond->GetShape(), 0);
// 
//                     Expr m1 = IndexedRead(m, index, 0);
//                     Expr mVal = IndexedWrite(Detach(other.m) - m1, index, m.GetShape(), 0);
//                     m = m + mVal;
// 
//                     Expr W1 = IndexedRead(W, index, 0);
//                     Expr WVal = IndexedWrite(Detach(other.W) - W1, index, W.GetShape(), 0);
//                     W = W + WVal;
// 
//                     Expr wAverage1 = IndexedRead(wAverage, index, 0);
//                     Expr wAverageVal = IndexedWrite(Detach(other.wAverage) - wAverage1, index, wAverage.GetShape(), 0);
//                     wAverage = wAverage + wAverageVal;
// 
//                     Expr n1 = IndexedRead(n, index, 0);
//                     Expr nVal = IndexedWrite(Detach(other.n) - n1, index, n.GetShape(), 0);
//                     n = n + nVal;
// 
//                     x.WriteToIndex(other.x, index);
//                 }
// 
//                 Reservoir GetIndexedCopy(const Expr& index, int size) const {
//                     Reservoir ret;
//                     ret.numOfSamples = size;
//                     ret.m = IndexedRead(m, index, 0);
//                     ret.wAverage = IndexedRead(wAverage, index, 0);
//                     ret.W = IndexedRead(W, index, 0);
//                     ret.n = IndexedRead(n, index, 0);
//                     ret.x = x.GetIndexedCopy(index, size);
// 
//                     return ret;
//                 }
// 
//                 void Update(const Reservoir& other, const Expr& sample, int mMax) {
//                     Expr G = (VectorDot(n, other.n));
//                     G = Tensorf(G);
//                     Expr nFilter = ((G > Scalar(0.5f)));
// 
//                     Expr m1 = Where(nFilter, other.m, Scalar(0));
//                     Expr W1 = Where(nFilter, other.W, Scalar(0.0f));
// 
//                     Update(other.x, W1 * other.x.pHat, m1, sample, mMax);
//                     W = Where(wAverage > Scalar(0.0f), wAverage / x.pHat, Scalar(0.0f));
//                     m = Where(m > Scalar(mMax * 20), Scalar(mMax * 20), m);
//                 }
// 
//                 void VisibilityFilter(const Scene& scene, const Ray& ray,
//                                       const Intersection& its) {
//                     Expr visibility = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
//                     for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
//                     {
//                         IndexMask mask_bsdf = its.mBsdfId == Scalar(iBSDF);
//                         Expr condBsdf = (its.mBsdfId == Scalar(iBSDF));
//                         if (mask_bsdf.sum == 0) continue;
//                         Intersection its_nee = its.GetMaskedCopy(mask_bsdf, true);
//                         PositionSample light_sample = x.positionSample.GetMaskedCopy(mask_bsdf);
//                         Expr wi = light_sample.p - its_nee.mPosition;
//                         auto dist_sqr = VectorSquaredLength(wi);
//                         auto dist = Sqrt(dist_sqr);
// 
//                         wi = wi / dist;
//                         auto dotShNorm = VectorDot(its_nee.mNormal, wi);
//                         auto dotGeoNorm = VectorDot(its_nee.mGeoNormal, wi);
// 
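//                         // Check whether the connection is valid. NOTE(review): only the shadow-ray
//                         // visibility test is applied here; numStable, noLightLeak and faceCheck are
//                         // computed but not used (unlike Shade, which multiplies numStable and noLightLeak in).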
//                         Expr connectValid;
//                         {
//                             auto numStable = (dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
//                             auto noLightLeak = ((dotShNorm * dotGeoNorm) > Scalar(0.0f));
//                             auto faceCheck = (dotShNorm > Scalar(0.0f));
//                             Expr visible;
//                             {
//                                 Ray shadowRays;
//                                 shadowRays.mNumRays = mask_bsdf.sum;
//                                 shadowRays.mOrg = its_nee.mPosition;
//                                 shadowRays.mDir = wi;
//                                 shadowRays.mMin = Ones(mask_bsdf.sum) * Scalar(SHADOW_EPSILON);
//                                 shadowRays.mMax = dist - Scalar(SHADOW_EPSILON);
//                                 scene.Occluded(shadowRays, visible);
//                             }
//                             connectValid =  visible;
//                         }
// 
//                         Expr val = IndexedWrite(connectValid, mask_bsdf.index, visibility->GetShape(), 0);
//                         visibility = visibility + val;
//                     }
// 
//                     Expr filter = (visibility > Scalar(0.0f));
//                     W = Where(filter, W, Scalar(0.0f));
//                     wAverage = Where(filter, wAverage, Scalar(0.0f));
//                 }
// 
//                 Expr Shade(const Scene& scene, const Ray& rays, 
//                            const Intersection& its) const {
//                     Expr li = Zeros(Shape({ rays.mNumRays }, VecType::Vec3));
//                     for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
//                     {
//                         IndexMask mask_bsdf = its.mBsdfId == Scalar(iBSDF);
//                         Expr condBsdf = (its.mBsdfId == Scalar(iBSDF));
//                         if (mask_bsdf.sum == 0) continue;
//                         Intersection its_nee = its.GetMaskedCopy(mask_bsdf, true);
//                         PositionSample light_sample = x.positionSample.GetMaskedCopy(mask_bsdf);
//                         Expr wi = light_sample.p - its_nee.mPosition;
//                         auto dist_sqr = VectorSquaredLength(wi);
//                         auto dist = Sqrt(dist_sqr);
// 
//                         wi = wi / dist;
//                         auto dotShNorm = VectorDot(its_nee.mNormal, wi);
//                         auto dotGeoNorm = VectorDot(its_nee.mGeoNormal, wi);
// 
//                         // check if the connection is valid
//                         Expr connectValid;
//                         {
//                             auto numStable = (dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
//                             auto noLightLeak = ((dotShNorm * dotGeoNorm) > Scalar(0.0f));
//                             Expr visible;
//                             {
//                                 Ray shadowRays;
//                                 shadowRays.mNumRays = mask_bsdf.sum;
//                                 shadowRays.mOrg = its_nee.mPosition;
//                                 shadowRays.mDir = wi;
//                                 shadowRays.mMin = Ones(mask_bsdf.sum) * Scalar(SHADOW_EPSILON);
//                                 shadowRays.mMax = dist - Scalar(SHADOW_EPSILON);
//                                 scene.Occluded(shadowRays, visible);
//                             }
//                             connectValid = numStable * noLightLeak * visible;
//                         }
//                         // evaluate the direct lighting from light sampling
//                         Expr throughput = Mask(rays.mThroughput, mask_bsdf, 0);
//                         Expr wo = Mask(-rays.mDir, mask_bsdf, 0);
//                         Expr nDotWi = VectorDot(light_sample.n, -wi);
//                         Expr G = Where(nDotWi > Scalar(0.0f), nDotWi, Scalar(0.0f)) / dist_sqr;
//                         // Expr G = Abs(VectorDot(light_sample.n, -wi)) / dist_sqr;
//                         Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(its_nee, wo, wi) * Abs(dotShNorm);
//                         Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(light_sample, -wi);
//                         Expr val = connectValid * throughput * bsdfVal * light_sample.J * G * Le;
//                         li = li + IndexedWrite(val, mask_bsdf.index, li->GetShape(), 0);
//                     }
//                     return NaNToZero(Detach(W) * li);
//                 }
// 
//                 void Eval() {
//                     x.Eval();
//                 }
// 
//                 void Reset() {
//                     m = Zeros(m.GetShape());
//                     n = Zeros(n.GetShape());
//                     wAverage = Zeros(wAverage.GetShape());
//                     W = Zeros(W.GetShape());
//                     numOfSamples = 0;
//                     return;
//                 }
// 
//             public:
//                 int numOfSamples;
//                 LightSample x;
//                 Tensorf wAverage;
//                 Tensorf W;
//                 Tensori m;
//                 Tensorf n;
//         };
// 
//         class RISPathTracer : public Integrator
//         {
//         public:
//             RISPathTracer()
//             {
// #if USE_BOX_FILTER
//                 mAntitheticSpp = 1;
// #else
//                 mAntitheticSpp = 4;
// #endif
//             }
// 
//             void SetParam(const RenderOptions& options) 
//             {
//                 const RISRenderOptions& renderOptions = dynamic_cast<const RISRenderOptions&>(options);
//                 mSpp = renderOptions.mSppInterior;
//                 mSppBatch = renderOptions.mSppInteriorBatch;
//                 mMaxBounces = renderOptions.mMaxBounces;
//                 mVerbose = !renderOptions.mQuiet;
//                 M = renderOptions.M;
//                 haveTempoarlReuse = renderOptions.haveTemporalReuse;
//             }
// 
//             Tensorf RenderC(const Scene& scene, const RenderOptions& options) override
//             {
//                 const Camera& camera = *scene.mSensors[0];
//                 Tensorf ret = Zeros(Shape({ camera.GetFilmSizeX() * camera.GetFilmSizeY() }, VecType::Vec3));
//                 SetParam(options);
//                 mDLoss.Free();
//                 Integrate(scene, ret);
//                 return ret;
//             }
// 
//             void Integrate(const Scene& scene, Tensorf& image) const;
// 
//             Expr Radiance(const Scene& scene, Ray& rays, Tensorf& image) const;
// 
//         private:
//             void SampleRadianceDirect(const Scene& scene, const Ray& rays,
//                     const Intersection& its, const Tensorf& rndLight,
//                     LightSample& lightSample) const;
//             Expr RISPathTracer::EvalPHat(const Scene& scene, const Ray& ray, const Intersection& its, 
//                                             const LightSample& lightSamples) const;
//         private:
//             mutable Reservoir reservoir;
//             int mAntitheticSpp;
//             int M = 4;
//             bool haveTempoarlReuse = false;
//         };
//     }
// }
