/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "Algorithm1.h"

namespace EDX
{
    namespace TensorRay
    {
        void ConvertToSpatialVertices(const Scene& scene, const MaterialVertices& matV, SpatialVertices& spaV)
        {
            spaV.prevId = matV.prevId;
            spaV.triangleId = matV.triangleId;
            spaV.emitterId = matV.emitterId;
            spaV.bsdfId = matV.bsdfId;
            spaV.u = matV.u;
            spaV.v = matV.v;
            spaV.pdf = matV.pdf;
            spaV.numOfSamples = matV.numOfSamples;

            if (matV.numOfSamples == 0) return;

            Expr indicesTri0 = Scalar(3) * matV.triangleId;
            Expr indicesTri1 = Scalar(3) * matV.triangleId + Scalar(1);
            Expr indicesTri2 = Scalar(3) * matV.triangleId + Scalar(2);
            Expr w = Scalar(1.f) - matV.u - matV.v;
            {
                Expr indicesPos0 = IndexedRead(scene.mIndexPosBuffer, indicesTri0, 0);
                Expr indicesPos1 = IndexedRead(scene.mIndexPosBuffer, indicesTri1, 0);
                Expr indicesPos2 = IndexedRead(scene.mIndexPosBuffer, indicesTri2, 0);
                Expr position0 = IndexedRead(scene.mPositionBuffer, indicesPos0, 0);
                Expr position1 = IndexedRead(scene.mPositionBuffer, indicesPos1, 0);
                Expr position2 = IndexedRead(scene.mPositionBuffer, indicesPos2, 0);
                spaV.position = w * position0 + matV.u * position1 + matV.v * position2;
            }

            if (scene.mTexcoordBuffer.LinearSize() > 0)
            {
                Expr indicesTex0 = IndexedRead(scene.mIndexTexBuffer, indicesTri0, 0);
                Expr indicesTex1 = IndexedRead(scene.mIndexTexBuffer, indicesTri1, 0);
                Expr indicesTex2 = IndexedRead(scene.mIndexTexBuffer, indicesTri2, 0);
                Expr texcoord0 = IndexedRead(scene.mTexcoordBuffer, indicesTex0, 0);
                Expr texcoord1 = IndexedRead(scene.mTexcoordBuffer, indicesTex1, 0);
                Expr texcoord2 = IndexedRead(scene.mTexcoordBuffer, indicesTex2, 0);
                spaV.texcoord = w * texcoord0 + matV.u * texcoord1 + matV.v * texcoord2;
            }

            spaV.geoNormal = IndexedRead(scene.mFaceNormalBuffer, matV.triangleId, 0);

            Expr useShadingNormal = IndexedRead(scene.mUseSmoothShadingBuffer, matV.triangleId, 0);
            spaV.normal = Zeros(spaV.geoNormal.GetShape());
            IndexMask shNormalMask = (useShadingNormal == True(1));
            if (shNormalMask.sum > 0)
            {
                Expr indicesNorm0 = IndexedRead(scene.mIndexNormalBuffer, Mask(indicesTri0, shNormalMask, 0), 0);
                Expr indicesNorm1 = IndexedRead(scene.mIndexNormalBuffer, Mask(indicesTri1, shNormalMask, 0), 0);
                Expr indicesNorm2 = IndexedRead(scene.mIndexNormalBuffer, Mask(indicesTri2, shNormalMask, 0), 0);
                Expr normal0 = IndexedRead(scene.mVertexNormalBuffer, indicesNorm0, 0);
                Expr normal1 = IndexedRead(scene.mVertexNormalBuffer, indicesNorm1, 0);
                Expr normal2 = IndexedRead(scene.mVertexNormalBuffer, indicesNorm2, 0);
                Expr shNormal = Mask(w, shNormalMask, 0) * normal0
                    + Mask(matV.u, shNormalMask, 0) * normal1
                    + Mask(matV.v, shNormalMask, 0) * normal2;
                spaV.normal = spaV.normal + IndexedWrite(shNormal, shNormalMask.index, spaV.geoNormal.GetShape(), 0);
            }
            IndexMask geoNormalMask = (useShadingNormal == False(1));
            if (geoNormalMask.sum > 0)
            {
                Expr geoNormal = Mask(spaV.geoNormal, geoNormalMask, 0);
                spaV.normal = spaV.normal + IndexedWrite(geoNormal, geoNormalMask.index, spaV.geoNormal.GetShape(), 0);
            }

            spaV.J = IndexedRead(scene.mTriangleAreaBuffer, matV.triangleId, 0);
            spaV.J = spaV.J / Detach(spaV.J);

            Expr tang, bitang;
            CoordinateSystem(spaV.normal, &tang, &bitang);
            spaV.tangent = tang;
            spaV.bitangent = bitang;
        }

        void EvalLe0Vertices(const Scene& scene, const SpatialVertices& vCur, Tensorf& throughput, Tensorf& le)
        {
            const Camera& camera = *scene.mSensors[0];
            throughput = camera.EvalFilter(vCur.prevId, vCur) * vCur.J;
            //pixelId = vCur.prevId;
            Expr rayDir = VectorNormalize(vCur.position - camera.mPosTensor);

            IndexMask isEmitter = (vCur.emitterId != Scalar(-1));
            if (isEmitter.sum > 0)
            {
                SpatialVertices vCurEmitter = vCur.GetMaskedCopy(isEmitter);
                le = scene.mLights[scene.mAreaLightIndex]->Eval(vCurEmitter, Mask(-rayDir, isEmitter, 0));
                //res = Mask(throughput, isEmitter, 0) * scene.mLights[scene.mAreaLightIndex]->Eval(vCurEmitter, Mask(-rayDir, isEmitter, 0));
                //activeIndex = isEmitter.index;
            }
        }

        /// Evaluate next-event-estimation (light sampling) contributions.
        ///
        /// Each vertex of vNEE is a point on the area light connected to the vCur
        /// vertex at index vNEE.prevId. The vertex before that one (vPrev indexed
        /// through vCur.prevId) provides the outgoing direction wo. The value per
        /// NEE sample is
        ///     BSDF(wo, wi) * |dot(n_shading, wi)| * G * Le * J / Detach(pdf),
        /// with G = |dot(n_light_geo, -wi)| / dist^2 and the BSDF dispatched on
        /// the shading vertex's bsdfId.
        /// No visibility test is performed here. NOTE(review): presumably callers
        /// pass only unoccluded samples; confirm. MIS is not applied yet.
        ///
        /// @param scene Scene providing BSDFs and the area light.
        /// @param vPrev Vertices preceding vCur along each path.
        /// @param vCur  Shading vertices.
        /// @param vNEE  Sampled light vertices.
        /// @param nee   Out: one contribution per vNEE vertex. Left untouched when
        ///              vNEE has no samples.
        void EvalNEEVertices(const Scene& scene, const SpatialVertices& vPrev, const SpatialVertices& vCur, const SpatialVertices& vNEE, Tensorf& nee)
        {
            if (vNEE.numOfSamples == 0) return;
            Expr _res = Zeros(vNEE.position.GetShape());

            // Gather the shading vertex (and its predecessor) for every light sample.
            SpatialVertices _vCur = vCur.GetIndexedCopy(vNEE.prevId);
            Expr pPrev = IndexedRead(IndexedRead(vPrev.position, vCur.prevId, 0), vNEE.prevId, 0);

            Expr wi = vNEE.position - _vCur.position;
            Expr distSqr = VectorSquaredLength(wi);
            Expr dist = Sqrt(distSqr);
            wi = wi / dist;
            Expr dotShNorm = VectorDot(_vCur.normal, wi);

            Expr wo = VectorNormalize(pPrev - _vCur.position);
            // The area-to-solid-angle conversion at the light depends on the actual
            // orientation of the emitting triangle, i.e. its geometric normal.
            // This matches the geometry term in EvalBSDFVertices. The shading
            // normal may be interpolated and is not the surface's true orientation.
            Expr G = Abs(VectorDot(vNEE.geoNormal, -wi)) / distSqr;
            Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(vNEE, -wi);
            Expr lightPdf = Detach(vNEE.pdf);
            Expr tmpVal = G * Le * vNEE.J / lightPdf;

            for (int iBSDF = 0; iBSDF < scene.mBSDFCount; iBSDF++)
            {
                IndexMask isActive = (_vCur.bsdfId == Scalar(iBSDF));
                if (isActive.sum == 0)
                    continue;

                SpatialVertices v = _vCur.GetMaskedCopy(isActive);
                Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(v, Mask(wo, isActive, 0), Mask(wi, isActive, 0)) * Abs(Mask(dotShNorm, isActive, 0));
                Expr val = bsdfVal * Mask(tmpVal, isActive, 0);

                // TODO: MIS

                _res = _res + IndexedWrite(val, isActive.index, _res->GetShape(), 0);
            }

            nee = _res;
        }

        void EvalBSDFVertices(const Scene& scene, const SpatialVertices& vPrev, const SpatialVertices& vCur, const SpatialVertices& vNext, Tensorf& subThroughput)
        {
            Expr _res = Zeros(vNext.position.GetShape());

            SpatialVertices _vCur = vCur.GetIndexedCopy(vNext.prevId);
            Expr pPrev = IndexedRead(IndexedRead(vPrev.position, vCur.prevId, 0), vNext.prevId, 0);

            Expr wi = vNext.position - _vCur.position;
            Expr distSqr = VectorSquaredLength(wi);
            Expr dist = Sqrt(distSqr);
            wi = wi / dist;
            Expr dotShNorm = VectorDot(_vCur.normal, wi);

            Expr wo = VectorNormalize(pPrev - _vCur.position);
            Expr G = Abs(VectorDot(vNext.geoNormal, -wi)) / distSqr;  // Note: use geoNormal or shNormal?
            Expr bsdfPdf = Detach(vNext.pdf * G);
            Expr tmpVal = Where(G > Scalar(0.f), G * vNext.J / bsdfPdf, Scalar(0.f));

            for (int iBSDF = 0; iBSDF < scene.mBSDFCount; iBSDF++)
            {
                IndexMask isActive = (_vCur.bsdfId == Scalar(iBSDF));
                if (isActive.sum == 0)
                    continue;

                SpatialVertices v = _vCur.GetMaskedCopy(isActive);
                Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(v, Mask(wo, isActive, 0), Mask(wi, isActive, 0)) * Abs(Mask(dotShNorm, isActive, 0));
                Expr val = bsdfVal * Mask(tmpVal, isActive, 0);
                _res = _res + IndexedWrite(val, isActive.index, _res->GetShape(), 0);
            }

            subThroughput = _res;
        }
    }
}
