/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "Path2.h"

namespace EDX
{
    namespace TensorRay
    {
        void Path2::SetParam(const RenderOptions& options)
        {
            PathTracer::SetParam(options);
            mPathSampler.SetParam(options);
        }

        Expr Path2::EvalPath(const Scene& scene, const PathSpatial& paths, const PathContribTerms& h, const Shape& imageShape) const
        {
            // Builds an image-shaped expression by splatting per-path contributions into
            // their pixels and dividing the sum by mSpp.
            //
            // Accumulated terms:
            //  - emission at the primary hit (vertices of paths.vShade[0] whose emitterId != -1),
            //    weighted by h.leContrib[0];
            //  - next-event-estimation terms h.neeContrib[b] for each bounce b < mMaxBounces.
            // `scene` is not read in this function.
            Expr radiance = Zeros(imageShape);

            // Path weight and destination pixel per current vertex. Both are meant to be
            // re-indexed together whenever the path advances a bounce, so their sizes match.
            Tensorf weight = h.subThroughput[0];
            Tensori pixelOf = paths.vShade[0].prevId;

            // Emission seen directly at the primary hit.
            IndexMask hitsEmitter = (paths.vShade[0].emitterId != Scalar(-1));
            if (hitsEmitter.sum > 0)
            {
                Expr emitted = IndexedRead(weight, hitsEmitter.index, 0) * h.leContrib[0];
                Expr targetPixel = IndexedRead(pixelOf, hitsEmitter.index, 0);
                radiance = radiance + IndexedWrite(emitted, targetPixel, imageShape, 0);
            }

            for (int bounce = 0; bounce < mMaxBounces; ++bounce)
            {
                // Next event estimation: nothing to splat for an empty NEE vertex set.
                const SpatialVertices& nee = paths.vNEE[bounce];
                if (nee.prevId.GetShape(0) <= 0)
                    continue;

                Expr direct = IndexedRead(weight, nee.prevId, 0) * h.neeContrib[bounce];
                Expr targetPixel = IndexedRead(pixelOf, nee.prevId, 0);
                radiance = radiance + IndexedWrite(direct, targetPixel, imageShape, 0);

                // TODO: BSDF sampling is disabled. When re-enabled, `weight` and `pixelOf`
                // must be re-indexed by paths.vShade[bounce + 1].prevId, and `weight` must be
                // multiplied by h.subThroughput[bounce + 1] before the next iteration.
            }

            return radiance / Scalar(mSpp);
        }

        void Path2::Integrate(const Scene& scene, Tensorf& image) const
        {
            if (mSpp == 0) return;
            Timer timer;
            timer.Start();
            
            const Camera& camera = *scene.mSensors[0];
            int numRaysPerPass = camera.mResX * camera.mResY * mSppBatch;

            std::vector<SpatialVertices> vShade(mMaxBounces + 1);
            std::vector<SpatialVertices> vNEE(mMaxBounces);
            
            int npass = std::ceil(mSpp / mSppBatch);
            for (int ipass = 0; ipass < npass; ipass++)
            {
                PathSampleResult matPaths;
                mPathSampler.SamplePaths(scene, matPaths);

                PathSpatial spaPaths;
                spaPaths.vShade.resize(mMaxBounces + 1);
                spaPaths.vNEE.resize(mMaxBounces);

                PathContribTerms h;
                h.subThroughput.resize(mMaxBounces + 1);
                h.neeContrib.resize(mMaxBounces);
                h.leContrib.resize(1);

                // Spatial vertices for camera
                SpatialVertices vCamera;
                vCamera.position = Broadcast(camera.mPosTensor, Shape({ numRaysPerPass }, VecType::Vec3));

                // Eval vertices
                for (int iBounce = 0; iBounce < mMaxBounces; iBounce++)
                    ConvertToSpatialVertices(scene, matPaths.mPathVertices[0], spaPaths.vShade[0]);
                for (int iBounce = 0; iBounce < mMaxBounces; iBounce++)
                    ConvertToSpatialVertices(scene, matPaths.mNEEVertices[0], spaPaths.vNEE[0]);

                // Primary hit
                EvalLe0Vertices(scene, spaPaths.vShade[0], h.subThroughput[0], h.leContrib[0]);

                SpatialVertices* vPrev = &vCamera;
                SpatialVertices* vCur = &(spaPaths.vShade[0]);
                for (int iBounce = 0; iBounce < mMaxBounces; iBounce++)
                {
                    // NEE
                    EvalNEEVertices(scene, *vPrev, *vCur, spaPaths.vNEE[iBounce], h.neeContrib[iBounce]);

                    // BSDF sampling
                    // EvalBSDFVertices(scene, *vPrev, *vCur, spaPaths.vShade[iBounce + 1], h.subThroughput[iBounce + 1]);
                    // vPrev = vCur;
                    // vCur = &(spaPaths.vShade[iBounce + 1]);
                }

                Expr Li = EvalPath(scene, spaPaths, h, image.GetShape());

                Tensorf result = Li;
                mGradHandler.AccumulateDeriv(result);
                if (mDLoss.Empty())
                {
                    // RenderC: update returned image
                    image = image + Detach(result);
                }
                else
                {
                    // RenderD: backward + update dervaitive image (optional)
                    result.Backward(mDLoss);
                    AccumulateGradsAndReleaseGraph();
                }

                if (mVerbose)
                    std::cout << string_format("[Path2] #Pass %d / %d, %d kernels launched\r", ipass + 1, npass, KernelLaunchCounter::GetHandle());
                KernelLaunchCounter::Reset();
            }
            if (mVerbose)
                std::cout << string_format("[Path2] Total Elapsed time = %f (%f samples/pass, %d pass)", timer.GetElapsedTime(), mSppBatch, npass) << std::endl;
        }
    }
}
