// Copyright @yucwang 2022

#include "RISPathTracer.h"

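// NOTE(review): The entire RISPathTracer implementation below is commented out, so this
// file currently contributes nothing beyond the RISPathTracer.h include. If it is
// re-enabled: the locals `condBsdf` and `dotGeoNorm` are unused in EvalPHat and
// SampleRadianceDirect, and VisibilityFilter is called twice on the temporal-reuse path.
// namespace EDX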
// {
//     namespace TensorRay
//     {
//         Expr RISPathTracer::Radiance(const Scene& scene, Ray& rays, Tensorf& image) const
//         {
//             Expr contrib = Zeros(image.GetShape());
//             int screenSpaceSize = image.GetShape(0);
//             int nRaysInit = (mSppBatch / mAntitheticSpp) * image.GetShape(0);
//              // Path tracing
//              Intersection its;
//              // Emitted Radiance
//              scene.IntersectHit(rays, its);
//              if (rays.mNumRays > 0)
//              {
//                  scene.PostIntersect(its);
//                  const Camera& camera = *scene.mSensors[0];
//                  rays.mThroughput = camera.EvalFilter(rays.mPixelIdx, its) * its.mJ;
//                  rays.mDir = VectorNormalize(its.mPosition - rays.mOrg);
//                  Expr Le = EvalRadianceEmitted(scene, rays, its);
//                  contrib = contrib + IndexedWrite(Le, rays.mPixelIdx, image.GetShape(), 0);
//             }
//          
//             // Light Sampling
//             Reservoir curReservoir(rays.mNumRays);
//             curReservoir.n = its.mNormal;
//             for (int i = 0; i < M; ++i) {
//                 if (rays.mNumRays > 0) 
//                 {
//                     LightSample lightSample(rays.mNumRays);
//                     Tensorf antitheticRnd_light = Tensorf::RandomFloat(
//                         Shape({ nRaysInit }, VecType::Vec2));
//                     Expr rnd_light = IndexedRead(antitheticRnd_light, 
//                                                  rays.mRayIdx % Scalar(nRaysInit), 0);
//                     SampleRadianceDirect(scene, rays, its, rnd_light, lightSample);
//                     Expr reservoirSample = Tensorf::RandomFloat(
//                         Shape({ nRaysInit }, VecType::Scalar1));
//                     Expr rndReservoir = IndexedRead(reservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
//                     curReservoir.Update(lightSample, 
//                                      lightSample.pHat / lightSample.positionSample.pdf, Scalar(1),
//                                      rndReservoir, M);
//                 }
//             }
//             curReservoir.W = Where(curReservoir.x.pHat > Scalar(0.0f), 
//                                 curReservoir.wAverage / curReservoir.x.pHat, Scalar(0.0f));
//             
//             if (haveTempoarlReuse) {
//                 curReservoir.VisibilityFilter(scene, rays, its);
//                 if (reservoir.numOfSamples > 0) {
//                     Reservoir previousReservoir = reservoir.GetIndexedCopy(rays.mPixelIdx, rays.mNumRays);
//                     Expr newPHat = EvalPHat(scene, rays, its, previousReservoir.x);
//                     previousReservoir.x.pHat = newPHat;
//                     Expr reservoirSample = Tensorf::RandomFloat(
//                         Shape({ nRaysInit }, VecType::Scalar1));
//                     Expr rndReservoir = IndexedRead(reservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
//                     curReservoir.Update(previousReservoir, rndReservoir, M);
//                     
//                     // TODO: check whether this second visibility test can be folded into shading
//                     curReservoir.VisibilityFilter(scene, rays, its);
//                 }
//             }
//             curReservoir.Eval();
//             contrib = contrib + IndexedWrite(curReservoir.Shade(scene, rays, its), 
//                                         rays.mPixelIdx, image.GetShape(), 0);
//             contrib = contrib * Scalar(1.0f / float(mSpp));
// 
//             if (haveTempoarlReuse) {
//                 if (reservoir.numOfSamples == 0) {
//                     reservoir = Reservoir(screenSpaceSize);
//                     reservoir.x.numOfSamples = screenSpaceSize;
//                     reservoir.numOfSamples = screenSpaceSize;
//                     reservoir.x.positionSample.mNumSample = screenSpaceSize;
//                     reservoir.x.positionSample.lightId = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.lightId->GetShape().mVecType));
//                     reservoir.x.positionSample.J = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.J->GetShape().mVecType));
//                     reservoir.x.positionSample.n = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.n->GetShape().mVecType));
//                     reservoir.x.positionSample.p = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.p->GetShape().mVecType));
//                     reservoir.x.positionSample.pdf = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.pdf->GetShape().mVecType));
//                     reservoir.x.positionSample.triangleId = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.triangleId->GetShape().mVecType));
//                     reservoir.x.positionSample.baryU = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.baryU->GetShape().mVecType));
//                     reservoir.x.positionSample.baryV = Zeros(Shape({ screenSpaceSize }, curReservoir.x.positionSample.baryV->GetShape().mVecType));
//                 }
//                 reservoir.Eval();
//                 reservoir.WriteToIndex(curReservoir, rays.mPixelIdx);
//             }
//             return contrib;
//         }
// 
//         void RISPathTracer::Integrate(const Scene& scene, Tensorf& image) const
//         {
// #if USE_PROFILING
//             nvtxRangePushA(__FUNCTION__);
// #endif
//             if (mSpp == 0) return;
//             const Camera& camera = *scene.mSensors[0];
//             Timer timer;
//             timer.Start();
//             // Accumulate the output image (or its derivatives) over mSpp / mSppBatch passes of mSppBatch samples each
//             int npass = mSpp / mSppBatch;
// 
//             for (int ipass = 0; ipass < npass; ipass++)
//             {
//                 Expr contrbPass = Zeros(image.GetShape());
// 
//                 Ray rays;
//                 // Generate antithetic rays
//                 camera.GenerateAntitheticRays(rays, mSppBatch, mAntitheticSpp);
//                 contrbPass = contrbPass + Radiance(scene, rays, image);
//                 
//                 Tensorf result = contrbPass;
//                 mGradHandler.AccumulateDeriv(result);
//                 if (mDLoss.Empty()) 
//                 {
//                     // RenderC: update returned image
//                     image = image + Detach(result);
//                 } 
//                 else 
//                 {
//                     // RenderD: backward pass + update derivative image (optional)
//                     result.Backward(mDLoss);
//                     AccumulateGradsAndReleaseGraph();
//                 }
// 
//                 if (mVerbose)
//                     std::cout << string_format("[RISPathTracer] #Pass %d / %d, %d kernels launched\r", ipass + 1, npass, KernelLaunchCounter::GetHandle());
//                 KernelLaunchCounter::Reset();
//             }
//             if (mVerbose)
//                 std::cout << string_format("[RISPathTracer] Total Elapsed time = %f (%d samples/pass, %d pass)", timer.GetElapsedTime(), mSppBatch, npass) << std::endl;
// 
// #if USE_PROFILING
//             nvtxRangePop();
// #endif
//         }
// 
//         Expr RISPathTracer::EvalPHat(const Scene& scene, const Ray& rays, 
//                         const Intersection& its, const LightSample& lightSamples) const {
//             Expr pHat = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
//             pHat = pHat + Scalar(EPSILON);
//             // Evaluate (no sampling) the target function pHat = Luminance(bsdf * G * Le), per BSDF group
//             for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
//             {
//                 IndexMask mask_bsdf = its.mBsdfId == Scalar(iBSDF);
//                 Expr condBsdf = (its.mBsdfId == Scalar(iBSDF));
//                 if (mask_bsdf.sum == 0) continue;
//                 Intersection its_nee = its.GetMaskedCopy(mask_bsdf, true);
//                 PositionSample light_sample = lightSamples.positionSample.GetMaskedCopy(mask_bsdf);
//                 Expr wi = light_sample.p - its_nee.mPosition;
//                 auto dist_sqr = VectorSquaredLength(wi);
//                 auto dist = Sqrt(dist_sqr);
// 
//                 wi = wi / dist;
//                 auto dotShNorm = VectorDot(its_nee.mNormal, wi);
//                 auto dotGeoNorm = VectorDot(its_nee.mGeoNormal, wi);
// 
//                 // Compute the weight of the light samples
//                 // Expr throughput = Mask(rays.mThroughput, mask_bsdf, 0);
//                 Expr wo = Mask(-rays.mDir, mask_bsdf, 0);
//                 Expr nDotWi = VectorDot(light_sample.n, -wi);
//                 Expr G = Where(nDotWi > Scalar(0.0f), nDotWi, Scalar(0.0f)) / dist_sqr;
//                 Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(its_nee, wo, wi) * Abs(dotShNorm);
//                 Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(light_sample, -wi);
//                 pHat = pHat + IndexedWrite(Luminance(bsdfVal * G * Le), mask_bsdf.index, pHat->GetShape(), 0);
//             }
// 
//             return pHat;
//         }
// 
//         void RISPathTracer::SampleRadianceDirect(const Scene& scene, 
//                                 const Ray& rays, const Intersection& its, 
//                                 const Tensorf& rnd_light, LightSample& lightSamples) const
//         {
//             scene.mLights[scene.mAreaLightIndex]->Sample(rnd_light, lightSamples.positionSample);
//             for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
//             {
//                 IndexMask mask_bsdf = its.mBsdfId == Scalar(iBSDF);
//                 Expr condBsdf = (its.mBsdfId == Scalar(iBSDF));
//                 if (mask_bsdf.sum == 0) continue;
//                 Intersection its_nee = its.GetMaskedCopy(mask_bsdf, true);
//                 PositionSample light_sample = lightSamples.positionSample.GetMaskedCopy(mask_bsdf);
//                 Expr wi = light_sample.p - its_nee.mPosition;
//                 auto dist_sqr = VectorSquaredLength(wi);
//                 auto dist = Sqrt(dist_sqr);
// 
//                 wi = wi / dist;
//                 auto dotShNorm = VectorDot(its_nee.mNormal, wi);
//                 auto dotGeoNorm = VectorDot(its_nee.mGeoNormal, wi);
// 
//                 // Compute the weight of the light samples
//                 Expr wo = Mask(-rays.mDir, mask_bsdf, 0);
//                 Expr nDotWi = VectorDot(light_sample.n, -wi);
//                 nDotWi = Tensorf(nDotWi);
//                 Expr G = Where(nDotWi > Scalar(0.0f), nDotWi, Scalar(0.0f)) / dist_sqr;
//                 Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(its_nee, wo, wi) * Abs(dotShNorm);
//                 Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(light_sample, -wi);
//                 lightSamples.pHat = lightSamples.pHat + 
//                             IndexedWrite(Luminance(bsdfVal * G * Le), mask_bsdf.index,
//                             lightSamples.pHat->GetShape(), 0);
//             }
//         }
//     }
// }