/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once
#include "../../Tensor/Tensor.h"
using namespace EDX;
using namespace EDX::DeepLearning;

namespace EDX
{
    namespace TensorRay
    {
        struct MaterialVertices
        {
            Tensori prevId;  // Tracking the reduction made by 1) path tracing or 2) NEE shadow test
            Tensori triangleId;
            Tensori emitterId;
            Tensori bsdfId;
            Tensorf u;
            Tensorf v;
            Tensorf pdf;     // NEE sampling pdf or BSDF sampling pdf
            int numOfSamples;

            MaterialVertices(): numOfSamples(0) {}
            MaterialVertices(int _numOfSamples): numOfSamples(_numOfSamples) {
                prevId = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                triangleId = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                emitterId = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                bsdfId = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                u = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                v = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
                pdf = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
            }

            static void Combine(const MaterialVertices& v1, const MaterialVertices& v2, MaterialVertices& dst) {
                if (v1.numOfSamples == 0) {
                    dst.prevId = v2.prevId;
                    dst.triangleId = v2.triangleId;
                    dst.emitterId = v2.emitterId;
                    dst.bsdfId = v2.bsdfId;
                    dst.u = v2.u;
                    dst.v = v2.v;
                    dst.pdf = v2.pdf;
                } else if (v2.numOfSamples == 0) {
                    dst.prevId = v1.prevId;
                    dst.triangleId = v1.triangleId;
                    dst.emitterId = v1.emitterId;
                    dst.bsdfId = v1.bsdfId;
                    dst.u = v1.u;
                    dst.v = v1.v;
                    dst.pdf = v1.pdf;
                } else {
                    dst.prevId = Detach(Concat(v1.prevId, v2.prevId, 0));
                    dst.triangleId = Detach(Concat(v1.triangleId, v2.triangleId, 0));
                    dst.emitterId = Detach(Concat(v1.emitterId, v2.emitterId, 0));
                    dst.bsdfId = Detach(Concat(v1.bsdfId, v2.bsdfId, 0));
                    dst.u = Detach(Concat(v1.u, v2.u, 0));
                    dst.v = Detach(Concat(v1.v, v2.v, 0));
                    dst.pdf = Detach(Concat(v1.pdf, v2.pdf, 0));
                }
                dst.numOfSamples = v1.numOfSamples + v2.numOfSamples;
            }

            MaterialVertices GetMaskedCopy(const IndexMask& mask) const
            {
                MaterialVertices res;
                res.numOfSamples = mask.sum;
                res.prevId = Mask(prevId, mask, 0);
                res.triangleId = Mask(triangleId, mask, 0);
                res.emitterId = Mask(emitterId, mask, 0);
                res.bsdfId = Mask(bsdfId, mask, 0);
                res.u = Mask(u, mask, 0);
                res.v = Mask(v, mask, 0);
                res.pdf = Mask(pdf, mask, 0);
                return res;
            }

            MaterialVertices GetIndexedCopy(const Tensori& index) const
            {
                MaterialVertices res;
                res.prevId = IndexedRead(prevId, index, 0);
                res.triangleId = IndexedRead(triangleId, index, 0);
                res.emitterId = IndexedRead(emitterId, index, 0);
                res.bsdfId = IndexedRead(bsdfId, index, 0);
                res.u = IndexedRead(u, index, 0);
                res.v = IndexedRead(v, index, 0);
                res.pdf = IndexedRead(pdf, index, 0);
                return res;
            }

            void UpdateMasked(const MaterialVertices& other, const Expr& mask) {
                prevId = Where(mask, other.prevId, prevId);
                triangleId = Where(mask, other.triangleId, triangleId);
                emitterId = Where(mask, other.emitterId, emitterId);
                bsdfId = Where(mask, other.bsdfId, bsdfId);
                u = Where(mask, other.u, u);
                v = Where(mask, other.v, v);
                pdf = Where(mask, other.pdf, pdf);
            }

            void UpdateIndexed(const MaterialVertices& other, const Tensori& index) {
                Tensori prevId1 = IndexedRead(prevId, index, 0);
                prevId = Detach(prevId + IndexedWrite(other.prevId - prevId1, index, prevId.GetShape(), 0));

                Tensori triangleId1 = IndexedRead(triangleId, index, 0);
                triangleId = Detach(triangleId + IndexedWrite(other.triangleId - triangleId1, index, triangleId.GetShape(), 0));

                Tensori emitterId1 = IndexedRead(emitterId, index, 0);
                emitterId = Detach(emitterId + IndexedWrite(other.emitterId - emitterId1, index, emitterId.GetShape(), 0));

                Tensori bsdfId1 = IndexedRead(bsdfId, index, 0);
                bsdfId = Detach(bsdfId + IndexedWrite(other.bsdfId - bsdfId1, index, bsdfId.GetShape(), 0));

                Tensorf u1 = IndexedRead(u, index, 0);
                u = Detach(u + IndexedWrite(other.u - u1, index, u.GetShape(), 0));

                Tensorf v1 = IndexedRead(v, index, 0);
                v = Detach(v + IndexedWrite(other.v - v1, index, v.GetShape(), 0));

                Tensorf pdf1 = IndexedRead(pdf, index, 0);
                pdf = Detach(pdf + IndexedWrite(other.pdf - pdf1, index, pdf.GetShape(), 0));
            }
        };

        struct SpatialVertices : public MaterialVertices
        {
            Tensorf position;
            Tensorf normal;
            Tensorf geoNormal;
            Tensorf texcoord;
            Tensorf tangent;
            Tensorf bitangent;
            Tensorf J;

            SpatialVertices(): MaterialVertices() {}
            SpatialVertices(const MaterialVertices& matV) : MaterialVertices(matV) {}

            SpatialVertices(int numOfSamples): MaterialVertices(numOfSamples) {
                position = Zeros(Shape({ numOfSamples }, VecType::Vec3));
                normal = Zeros(Shape({ numOfSamples }, VecType::Vec3));
                geoNormal = Zeros(Shape({ numOfSamples }, VecType::Vec3));
                texcoord = Zeros(Shape({ numOfSamples }, VecType::Vec2));
                tangent = Zeros(Shape({ numOfSamples }, VecType::Vec3));
                bitangent = Zeros(Shape({ numOfSamples }, VecType::Vec3));
                J = Zeros(Shape({ numOfSamples }, VecType::Scalar1));
            }

            static void Combine(const SpatialVertices& v1, const SpatialVertices& v2, SpatialVertices& dst) {
                if (v1.numOfSamples == 0) {
                    dst.position = v2.position;
                    dst.normal = v2.normal;
                    dst.geoNormal = v2.geoNormal;
                    dst.texcoord = v2.texcoord;
                    dst.tangent = v2.tangent;
                    dst.bitangent = v2.bitangent;
                    dst.J = v2.J;
                } else if (v2.numOfSamples == 0) {
                    dst.position = v1.position;
                    dst.normal = v1.normal;
                    dst.geoNormal = v1.geoNormal;
                    dst.texcoord = v1.texcoord;
                    dst.tangent = v1.tangent;
                    dst.bitangent = v1.bitangent;
                    dst.J = v1.J;
                } else {
                    dst.position = Detach(Concat(v1.position, v2.position, 0));
                    dst.normal = Detach(Concat(v1.normal, v2.normal, 0));
                    dst.geoNormal = Detach(Concat(v1.geoNormal, v2.geoNormal, 0));
                    dst.texcoord = Detach(Concat(v1.texcoord, v2.texcoord, 0));
                    dst.tangent = Detach(Concat(v1.tangent, v2.tangent, 0));
                    dst.bitangent = Detach(Concat(v1.bitangent, v2.bitangent, 0));
                    dst.J = Detach(Concat(v1.J, v2.J, 0));
                }
                MaterialVertices::Combine(v1, v2, dst);
            }                                   

            SpatialVertices GetMaskedCopy(const IndexMask& mask) const
            {
                SpatialVertices res = MaterialVertices::GetMaskedCopy(mask);
                res.position = Mask(position, mask, 0);
                res.normal = Mask(normal, mask, 0);
                res.geoNormal = Mask(geoNormal, mask, 0);
                res.texcoord = Mask(texcoord, mask, 0);
                res.tangent = Mask(tangent, mask, 0);
                res.bitangent = Mask(bitangent, mask, 0);
                res.J = Mask(J, mask, 0);
                return res;
            }

            SpatialVertices GetIndexedCopy(const Tensori& index) const
            {
                SpatialVertices res = MaterialVertices::GetIndexedCopy(index);
                res.position = IndexedRead(position, index, 0);
                res.normal = IndexedRead(normal, index, 0);
                res.geoNormal = IndexedRead(geoNormal, index, 0);
                res.texcoord = IndexedRead(texcoord, index, 0);
                res.tangent = IndexedRead(tangent, index, 0);
                res.bitangent = IndexedRead(bitangent, index, 0);
                res.J = IndexedRead(J, index, 0);
                return res;
            }

            SpatialVertices UpdateMasked(const SpatialVertices& other, const Expr& mask) {
                MaterialVertices::UpdateMasked(other, mask);
                position = Where(mask, other.position, position);
                normal = Where(mask, other.normal, normal);
                geoNormal = Where(mask, other.geoNormal, geoNormal);
                texcoord = Where(mask, other.texcoord, texcoord);
                tangent = Where(mask, other.tangent, tangent);
                bitangent = Where(mask, other.bitangent, bitangent);
                J = Where(mask, other.J, J);
            }

            void UpdateIndexed(const SpatialVertices& other, const Tensori& index) {
                MaterialVertices::UpdateIndexed(other, index);
                Tensorf position1 = IndexedRead(position, index, 0);
                position = Detach(position + IndexedWrite(other.position - position1, index, position.GetShape(), 0));

                Tensorf normal1 = IndexedRead(normal, index, 0);
                normal = Detach(normal + IndexedWrite(other.normal - normal1, index, normal.GetShape(), 0));

                Tensorf geoNormal1 = IndexedRead(geoNormal, index, 0);
                geoNormal = Detach(geoNormal + IndexedWrite(other.geoNormal - geoNormal1, index, geoNormal.GetShape(), 0));

                Tensorf texcoord1 = IndexedRead(texcoord, index, 0);
                texcoord = Detach(texcoord + IndexedWrite(other.texcoord - texcoord1, index, texcoord.GetShape(), 0));

                Tensorf tangent1 = IndexedRead(tangent, index, 0);
                tangent = Detach(tangent + IndexedWrite(other.tangent - tangent1, index, tangent.GetShape(), 0));

                Tensorf bitangent1 = IndexedRead(bitangent, index, 0);
                bitangent = Detach(bitangent + IndexedWrite(other.bitangent - bitangent1, index, bitangent.GetShape(), 0));

                Tensorf J1 = IndexedRead(J, index, 0);
                J = J + Detach(IndexedWrite(other.J - J1, index, J.GetShape(), 0));
            }
        };
    }
}
