/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "PathSampler.h"

namespace EDX
{
    namespace TensorRay
    {
        /**
         * Configure the sampler from the render options.
         *
         * Copies the per-batch sample count (mSppInteriorBatch) and the maximum number
         * of bounces, then selects the antithetic group size: the number of camera
         * samples that share one set of random numbers in SamplePaths. The group size is
         * fixed at build time by the USE_BOX_FILTER flag (1 when set, 4 otherwise).
         */
        void PathSampler::SetParam(const RenderOptions& options)
        {
            mMaxBounces = options.mMaxBounces;
            mSppBatch = options.mSppInteriorBatch;

#if USE_BOX_FILTER
            const int antitheticGroupSize = 1;
#else
            const int antitheticGroupSize = 4;
#endif
            mAntitheticSpp = antitheticGroupSize;
        }

        void PathSampler::SamplePaths(const Scene& scene, PathSampleResult& res) const
        {
            res.mPathVertices.clear();
            res.mPathVertices.resize(mMaxBounces + 1);
            res.mNEEVertices.clear();
            res.mNEEVertices.resize(mMaxBounces);

            const Camera& camera = *scene.mSensors[0];
            Ray rays;
            // Generate (antithetic) rays
            camera.GenerateAntitheticRays(rays, mSppBatch, mAntitheticSpp);

            Intersection its;
            // Handle primary rays
            scene.IntersectHit(rays, its);

            if (rays.mNumRays > 0)
            {
                scene.PostIntersect(its);
                MaterialVertices& matV = res.mPathVertices[0];
                matV.prevId = rays.mPixelIdx;
                matV.triangleId = its.mTriangleId;
                matV.emitterId = its.mEmitterId;
                matV.bsdfId = its.mBsdfId;
                matV.u = its.mBaryU;
                matV.v = its.mBaryV;
                matV.pdf = Ones(matV.u.GetShape());
                matV.numOfSamples = rays.mNumRays;
            }

            // Handle secondary rays
            int nRaysInit = (mSppBatch / mAntitheticSpp) * camera.mResX * camera.mResY;
            for (int iBounce = 0; iBounce < mMaxBounces; iBounce++)
            {
                if (rays.mNumRays == 0) 
                    break;
                Tensorf antitheticRndLight = Tensorf::RandomFloat(Shape({ nRaysInit }, VecType::Vec2));
                Tensorf antitheticRndBsdf = Tensorf::RandomFloat(Shape({ nRaysInit }, VecType::Vec3));

                Tensorf rndLight = IndexedRead(antitheticRndLight, rays.mRayIdx % Scalar(nRaysInit), 0);
                SampleNEE(scene, rays, its, rndLight, res.mNEEVertices[iBounce]);

                // Tensorf rndBsdf = IndexedRead(antitheticRndBsdf, rays.mRayIdx % Scalar(nRaysInit), 0);
                // Ray raysNext;
                // Intersection itsNext;
                // SampleBSDF(scene, rays, its, rndBsdf, res.mPathVertices[iBounce + 1], raysNext, itsNext);
                // rays = raysNext;
                // its = itsNext;
            }
        }

        void SampleNEE(const Scene& scene, const Ray& rays, const Intersection& its, const Tensorf& rndLight, MaterialVertices& vNEE)
        {
            // Sample point on area lights
            PositionSample lightSamples;
            scene.mLights[scene.mAreaLightIndex]->Sample(rndLight, lightSamples);

            IndexMask isShadingPoint = (its.mEmitterId == Scalar(-1));
            if (isShadingPoint.sum > 0)
            {
                Intersection itsNee = its.GetMaskedCopy(isShadingPoint, true);
                lightSamples = lightSamples.GetMaskedCopy(isShadingPoint);
                Expr wi = lightSamples.p - itsNee.mPosition;
                Expr distSqr = VectorSquaredLength(wi);
                Expr dist = Sqrt(distSqr);
                wi = wi / dist;
                Expr dotShNorm = VectorDot(itsNee.mNormal, wi);
                Expr dotGeoNorm = VectorDot(itsNee.mGeoNormal, wi);

                // check if the connection is valid
                Expr connectValid;
                {
                    Expr numStable = (dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
                    Expr noLightLeak = ((dotShNorm * dotGeoNorm) > Scalar(0.0f));
                    Expr visible;
                    {
                        Ray shadowRays;
                        shadowRays.mNumRays = isShadingPoint.sum;
                        shadowRays.mOrg = itsNee.mPosition;
                        shadowRays.mDir = wi;
                        shadowRays.mMin = Ones(isShadingPoint.sum) * Scalar(SHADOW_EPSILON);
                        shadowRays.mMax = dist - Scalar(SHADOW_EPSILON);
                        scene.Occluded(shadowRays, visible);
                    }
                    connectValid = numStable * noLightLeak * visible;
                }
                IndexMask isValid = (connectValid > Scalar(0));

                Expr validNEEIndex = IndexedRead(isShadingPoint.index, isValid.index, 0);
                vNEE.prevId = validNEEIndex;
                vNEE.triangleId = IndexedRead(scene.mEmitTriIdToTriIdBuffer, Mask(lightSamples.triangleId, isValid, 0), 0);
                vNEE.emitterId = Mask(lightSamples.lightId, isValid, 0);
                vNEE.bsdfId = Ones(vNEE.triangleId.GetShape()) * Scalar(-1);
                vNEE.u = Mask(lightSamples.baryU, isValid, 0);
                vNEE.v = Mask(lightSamples.baryV, isValid, 0);
                vNEE.pdf = Mask(lightSamples.pdf, isValid, 0);
                vNEE.numOfSamples = isValid.sum;
            }
        }

        /**
         * Sample a BSDF-scattered continuation ray for each intersection and trace it.
         *
         * Intersections are processed one BSDF at a time (grouped by mBsdfId). A scattered
         * ray is kept only if it hits a triangle, its sampling pdf is positive and the
         * BSDF reports the outgoing direction as valid. Kept rays are appended in BSDF
         * order, and prevId follows the same order.
         *
         * @param scene     Scene providing the BSDFs and intersection queries.
         * @param rays      Rays that produced `its`. Direction, pixel/ray index and
         *                  throughput are carried over to the scattered rays.
         * @param its       Current intersections.
         * @param rndBsdf   Random numbers forwarded to the BSDF's SampleOnly().
         * @param vShade    Output vertices at the new hits. prevId holds indices into
         *                  `its`; pdf is the detached sampling pdf; numOfSamples is the
         *                  number of kept rays.
         * @param raysNext  Kept scattered rays are appended here.
         * @param itsNext   Post-processed intersections of the kept rays are appended here.
         */
        void SampleBSDF(const Scene& scene, const Ray& rays, const Intersection& its, const Tensorf& rndBsdf, MaterialVertices& vShade, Ray& raysNext, Intersection& itsNext)
        {
            Tensori prevId;
            // Number of scattered rays kept across all BSDFs (reported as vShade.numOfSamples).
            int numKept = 0;
            for (int iBSDF = 0; iBSDF < scene.mBSDFCount; iBSDF++)
            {
                IndexMask bsdfMask = (its.mBsdfId == Scalar(iBSDF));
                if (bsdfMask.sum == 0)
                    continue;

                Intersection itsBSDF = its.GetMaskedCopy(bsdfMask, true);
                Expr wo = Mask(-rays.mDir, bsdfMask, 0);
                Expr wi, pdf;
                scene.mBsdfs[iBSDF]->SampleOnly(itsBSDF, wo, Mask(rndBsdf, bsdfMask, 0), &wi, &pdf);

                Ray scatteredRays;
                Intersection scatteredIts;
                scatteredRays.mNumRays = bsdfMask.sum;
                scatteredRays.mOrg = itsBSDF.mPosition;
                scatteredRays.mDir = wi;
                scatteredRays.mPrevPdf = Detach(pdf);
                scatteredRays.mSpecular = make_shared<ConstantExp<bool>>(scene.mBsdfs[iBSDF]->IsDelta(), Shape({ bsdfMask.sum }));
                scatteredRays.mMin = Ones(bsdfMask.sum) * Scalar(SHADOW_EPSILON);
                scatteredRays.mMax = Ones(bsdfMask.sum) * Scalar(1e32f);
                scatteredRays.mPixelIdx = Mask(rays.mPixelIdx, bsdfMask, 0);
                scatteredRays.mRayIdx = Mask(rays.mRayIdx, bsdfMask, 0);
                scatteredRays.mThroughput = Mask(rays.mThroughput, bsdfMask, 0);
                scene.Intersect(scatteredRays, scatteredIts);

                // Keep rays that hit geometry with a usable pdf and a valid outgoing direction.
                IndexMask isHit = (scatteredIts.mTriangleId != Scalar(-1)
                    && pdf > Scalar(0.0f)
                    && scene.mBsdfs[iBSDF]->OutDirValid(itsBSDF, wi));
                if (isHit.sum > 0)
                {
                    scatteredRays = scatteredRays.GetMaskedCopy(isHit, true);
                    scatteredIts = scatteredIts.GetMaskedCopy(isHit);
                    scene.PostIntersect(scatteredIts);

                    // Map indices within this BSDF group back to indices into `its`.
                    Tensori validIndex = IndexedRead(bsdfMask.index, isHit.index, 0);
                    if (prevId.Empty())
                        prevId = validIndex;
                    else
                        prevId = Concat(prevId, validIndex, 0);

                    raysNext.Append(scatteredRays);
                    itsNext.Append(scatteredIts);
                    numKept += isHit.sum;
                }
            }

            vShade.prevId = prevId;
            vShade.triangleId = itsNext.mTriangleId;
            vShade.emitterId = itsNext.mEmitterId;
            vShade.bsdfId = itsNext.mBsdfId;
            vShade.u = itsNext.mBaryU;
            vShade.v = itsNext.mBaryV;
            vShade.pdf = raysNext.mPrevPdf;
            // Previously left unset, unlike every other vertex producer in this file.
            vShade.numOfSamples = numKept;
        }

        void SampleEmitterDirect(const Scene& scene, 
                                 const Ray& rays, 
                                 const Intersection& its,
                                 const Tensorf& rndLight, 
                                 MaterialVertices& vLight) {
            // Sample point on area lights
            PositionSample lightSamples;
            scene.mLights[scene.mAreaLightIndex]->Sample(rndLight, lightSamples);

            // Expr wi = lightSamples.p - its.mPosition;
            // Expr distSqr = VectorSquaredLength(wi);
            // Expr dist = Sqrt(distSqr);
            // wi = wi / dist;
            // Expr dotShNorm = VectorDot(its.mNormal, wi);
            // Expr dotGeoNorm = VectorDot(its.mGeoNormal, wi);

            // Expr validNEEIndex = IndexedRead(isShadingPoint.index, isValid.index, 0);
            // vNEE.prevId = isShadingPoint.index;
            vLight.numOfSamples = rays.mNumRays;
            vLight.prevId = rays.mPixelIdx;
            vLight.triangleId = IndexedRead(scene.mEmitTriIdToTriIdBuffer, lightSamples.triangleId, 0);
            vLight.emitterId = lightSamples.lightId;
            vLight.bsdfId = Ones(vLight.triangleId.GetShape()) * Scalar(-1);
            vLight.u = lightSamples.baryU;
            vLight.v = lightSamples.baryV;
            vLight.pdf = lightSamples.pdf;
        }
    }
}
