// Copyright @yucwang 2022

#include "Direct2.h"

namespace EDX
{
    namespace TensorRay
    {
        /// Configures the integrator and its path sampler from `options`.
        ///
        /// The direct integrator only evaluates a single bounce, so a local copy
        /// of the options is made and its `mMaxBounces` is forced to 1 (a message
        /// is printed when the caller asked for something else). The caller's
        /// `options` object is never modified.
        ///
        /// @param options  Render options supplied by the caller.
        void Direct2::SetParam(const RenderOptions& options)
        {
            // Direct Integrator has maxBounces = 1
            RenderOptions _options = options;
            if (_options.mMaxBounces != 1) {
                std::cout << "[Direct Integrator]: Options.mMaxBounces != 1, setting it to 1." << std::endl;
                _options.mMaxBounces = 1;
            }
            PathTracer::SetParam(_options);
            // Hand the sampler the same clamped options as the integrator so the
            // two agree on the bounce count. Integrate() only reads bounce 0 of
            // the sampled path / NEE vertices.
            mPathSampler->SetParam(_options);
        }

        /// Builds the image-space radiance expression for a one-bounce path.
        ///
        /// Two contributions are scattered into an image of shape `imageShape`:
        ///   1. emission at primary hits whose `emitterId` is not -1, weighted by
        ///      `h.subThroughput[0]` and `h.leContrib[0]`;
        ///   2. the next-event-estimation term `h.subThroughput[0] * h.neeContrib[0]`,
        ///      added only when `paths.vNEE[0].prevId` is non-empty.
        ///
        /// @param scene       Not used by this function.
        /// @param paths       Spatial vertices; only index 0 of `vShade` / `vNEE` is read.
        /// @param h           Per-vertex contribution terms; only index 0 is read.
        /// @param imageShape  Shape of the returned image expression.
        /// @return Expression summing both contributions per pixel.
        Expr Direct2::EvalPath(const Scene& scene, const PathSpatial& paths, const PathContribTerms& h, const Shape& imageShape) const
        {
            Expr radiance = Zeros(imageShape);

            // Per-sample weight and destination pixel of every primary hit.
            Tensorf pathWeight = h.subThroughput[0];
            Tensori targetPixel = paths.vShade[0].prevId;

            // Emission seen directly at the primary hit.
            IndexMask emitterHits = (paths.vShade[0].emitterId != Scalar(-1));
            const bool anyEmitterHit = emitterHits.sum > 0;
            if (anyEmitterHit)
            {
                // Restrict weight and pixel index to the samples that hit an emitter.
                Expr emitted = IndexedRead(pathWeight, emitterHits.index, 0) * h.leContrib[0];
                Expr emittedPixel = IndexedRead(targetPixel, emitterHits.index, 0);
                radiance = radiance + IndexedWrite(emitted, emittedPixel, imageShape, 0);
            }

            // Next-event estimation; scattered with the full, unfiltered pixel index.
            const SpatialVertices& neeVertices = paths.vNEE[0];
            const bool anyNeeSample = neeVertices.prevId.GetShape(0) > 0;
            if (anyNeeSample)
            {
                Expr direct = pathWeight * h.neeContrib[0];
                radiance = radiance + IndexedWrite(direct, targetPixel, imageShape, 0);
            }

            return radiance;
        }

        void Direct2::Integrate(const Scene& scene, Tensorf& image) const
        {
            if (mSpp == 0) return;
            Timer timer;
            timer.Start();
            
            const Camera& camera = *scene.mSensors[0];
            int numRaysPerPass = camera.mResX * camera.mResY * mSppBatch;

            std::vector<SpatialVertices> vShade(mMaxBounces + 1);
            std::vector<SpatialVertices> vNEE(mMaxBounces);
            
            int npass = std::ceil(mSpp / mSppBatch);
            for (int ipass = 0; ipass < npass; ipass++)
            {
                PathSampleResult matPaths;
                if (mPath == nullptr) {
                    mPathSampler->SamplePaths(scene, matPaths);
                } else {
                    matPaths = *mPath;
                }

                PathSpatial spaPaths;
                spaPaths.vShade.resize(mMaxBounces + 1);
                spaPaths.vNEE.resize(mMaxBounces);

                PathContribTerms h;
                h.subThroughput.resize(mMaxBounces + 1);
                h.neeContrib.resize(mMaxBounces);
                h.leContrib.resize(1);

                // Spatial vertices for camera
                SpatialVertices vCamera;
                vCamera.position = Broadcast(camera.mPosTensor, Shape({ matPaths.mPathVertices[0].numOfSamples }, VecType::Vec3));

                // Eval vertices
                ConvertToSpatialVertices(scene, matPaths.mPathVertices[0], spaPaths.vShade[0]);
                ConvertToSpatialVertices(scene, matPaths.mNEEVertices[0], spaPaths.vNEE[0]);

                // Primary hit
                EvalLe0Vertices(scene, spaPaths.vShade[0], h.subThroughput[0], h.leContrib[0]);

                SpatialVertices* vPrev = &vCamera;
                SpatialVertices* vCur = &(spaPaths.vShade[0]);
                EvalScreenSpaceReSTIR(scene, *vPrev, *vCur, spaPaths.vNEE[0], h.neeContrib[0]);

                Expr Li = EvalPath(scene, spaPaths, h, image.GetShape());

                Tensorf result = Li / Scalar(mSpp);
                mGradHandler.AccumulateDeriv(result);
                if (mDLoss.Empty())
                {
                    // RenderC: update returned image
                    image = image + Detach(result);
                }
                else
                {
                    // RenderD: backward + update dervaitive image (optional)
                    result.Backward(mDLoss);
                    AccumulateGradsAndReleaseGraph();
                }

                if (mVerbose)
                    std::cout << string_format("[Direct2] #Pass %d / %d, %d kernels launched\r", ipass + 1, npass, KernelLaunchCounter::GetHandle());
                KernelLaunchCounter::Reset();
            }

            if (mVerbose)
                std::cout << string_format("[Direct2] Total Elapsed time = %f (%f samples/pass, %d pass)", timer.GetElapsedTime(), mSppBatch, npass) << std::endl;
        }

        void Direct2::SamplePath(const Scene& scene) const {
            PathSampleResult path;
            mPathSampler->SamplePaths(scene, path);
            mPath = std::make_shared<PathSampleResult>(std::move(path));
        }

        /// Forwards a step request to the path sampler.
        void Direct2::Step() {
            auto& sampler = *mPathSampler;
            sampler.Step();
        }
    }
}
