/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once
#include "../Tensor/Tensor.h"
#include "Config.h"

using namespace EDX;
using namespace EDX::DeepLearning;
namespace EDX
{
    namespace TensorRay
    {
        // Forward declarations of types defined outside this header.
        // Scene and Camera are used below only as const-reference parameters, so
        // their full definitions are not needed here. Ray and Intersection are not
        // referenced anywhere in the visible part of this header.
        class Scene;
        class Camera;
        class Ray;
        class Intersection;
        // Old boundary term
        /**
         * Index tensors describing a set of mesh edges.
         *
         * Combine() concatenates every index tensor along dimension 0 and sums
         * numEdges, so each tensor is expected to carry numEdges entries along
         * that dimension.
         */
        class EdgeIndexInfo
        {
        public:
            // Number of edges described by the index tensors below.
            int numEdges;
            // NOTE(review): presumably the two end-point vertex indices of each edge -- confirm.
            Tensori indexVert0;
            Tensori indexVert1;
            // NOTE(review): presumably the (up to two) triangles sharing each edge -- confirm.
            Tensori indexTri0;
            Tensori indexTri1;
            // NOTE(review): presumably a third vertex associated with each edge -- confirm.
            Tensori indexVert2;

            // Creates an empty edge set (numEdges == 0, default-constructed tensors).
            EdgeIndexInfo(): numEdges(0) {}
            // Defined outside this header.
            EdgeIndexInfo GetMaskedCopy(const IndexMask& mask) const;
            // Defined outside this header.
            EdgeIndexInfo GetIndexedCopy(const Tensori& index) const;

            /**
             * Writes the union of e1 and e2 (e1's entries first) into dst.
             *
             * @param e1  First edge set.
             * @param e2  Second edge set.
             * @param dst Receives the combined index tensors and edge count.
             *
             * If one side is empty the other side's tensors are assigned
             * directly; Concat is only used when both sides hold edges.
             */
            static void Combine(const EdgeIndexInfo& e1, const EdgeIndexInfo& e2, EdgeIndexInfo& dst) {
                if (e1.numEdges == 0) {
                    // Nothing on the left: the result is just e2 (also covers both empty).
                    dst.indexTri0 = e2.indexTri0;
                    dst.indexTri1 = e2.indexTri1;
                    dst.indexVert0 = e2.indexVert0;
                    dst.indexVert1 = e2.indexVert1;
                    dst.indexVert2 = e2.indexVert2;
                } else if (e2.numEdges == 0) {
                    // Nothing on the right: the result is just e1.
                    // (Previously this branch tested `e2.numEdges` for truthiness, which
                    // discarded e2's edges whenever both inputs were non-empty.)
                    dst.indexTri0 = e1.indexTri0;
                    dst.indexTri1 = e1.indexTri1;
                    dst.indexVert0 = e1.indexVert0;
                    dst.indexVert1 = e1.indexVert1;
                    dst.indexVert2 = e1.indexVert2;
                } else {
                    // Both sides hold edges: append e2 after e1 along dimension 0.
                    dst.indexTri0 = Concat(e1.indexTri0, e2.indexTri0, 0);
                    dst.indexTri1 = Concat(e1.indexTri1, e2.indexTri1, 0);
                    dst.indexVert0 = Concat(e1.indexVert0, e2.indexVert0, 0);
                    dst.indexVert1 = Concat(e1.indexVert1, e2.indexVert1, 0);
                    dst.indexVert2 = Concat(e1.indexVert2, e2.indexVert2, 0);
                }
                // Written last so that the inputs' counts are still intact if dst aliases one of them.
                dst.numEdges = e1.numEdges + e2.numEdges;
            }
        };

        /**
         * Identifies a set of edge points by ids rather than by geometry.
         * Combine() concatenates the id tensors along dimension 0 and sums
         * numPoints.
         */
        class EdgeMaterialRep {
        public:
            // Number of entries described by the id tensors.
            int numPoints;
            // The id of the shape that this edge belongs to.
            Tensori shapeId;
            // The index of the edge in the EdgeIndexInfo of the mesh.
            Tensori internalId;
            // NOTE(review): presumably an edge index across all shapes -- confirm.
            Tensori globalId;

            // Creates an empty representation (numPoints == 0).
            EdgeMaterialRep(): numPoints(0) {}
            // Defined outside this header.
            EdgeMaterialRep GetMaskedCopy(const IndexMask& mask) const;
            // Defined outside this header.
            EdgeMaterialRep GetIndexedCopy(const Tensori& index) const;

            /**
             * Stores e1 followed by e2 in dst.
             * When either input is empty the other one is assigned as-is
             * (e2 wins if both are empty); otherwise the tensors are concatenated.
             */
            static void Combine(const EdgeMaterialRep& e1, const EdgeMaterialRep& e2, EdgeMaterialRep& dst) {
                const bool lhsEmpty = (e1.numPoints == 0);
                const bool rhsEmpty = (e2.numPoints == 0);

                if (lhsEmpty || rhsEmpty) {
                    // Only one side can contribute; an empty e1 takes precedence in the check.
                    const EdgeMaterialRep& src = lhsEmpty ? e2 : e1;
                    dst.shapeId = src.shapeId;
                    dst.internalId = src.internalId;
                    dst.globalId = src.globalId;
                } else {
                    dst.shapeId = Concat(e1.shapeId, e2.shapeId, 0);
                    dst.internalId = Concat(e1.internalId, e2.internalId, 0);
                    dst.globalId = Concat(e1.globalId, e2.globalId, 0);
                }

                dst.numPoints = e1.numPoints + e2.numPoints;
            }
        };

        class EdgeSampleMaterialRep {
        public:
            int numSamples;
            std::shared_ptr<EdgeMaterialRep> edgeMaterialInfo;
            // An edge point: p = p_0 + \alpha * (p_1 - p_0)
            Tensorf alpha;

            EdgeSampleMaterialRep(): numSamples(0), edgeMaterialInfo(std::make_shared<EdgeMaterialRep>()) {}
            EdgeSampleMaterialRep GetMaskedCopy(const IndexMask& mask) const;
            EdgeSampleMaterialRep GetIndexedCopy(const Tensori& index) const;

            static void Combine(const EdgeSampleMaterialRep& e1, const EdgeSampleMaterialRep& e2, 
                        EdgeSampleMaterialRep& dst) {
                if (e1.numSamples == 0) {
                    dst.alpha = e2.alpha;
                } else if (e2.numSamples == 0) {
                    dst.alpha = e1.alpha;
                } else {
                    dst.alpha = Concat(e1.alpha, e2.alpha, 0);
                }
                EdgeMaterialRep rep;
                EdgeMaterialRep::Combine(*e1.edgeMaterialInfo, *e2.edgeMaterialInfo, rep);
                dst.edgeMaterialInfo = std::make_shared<EdgeMaterialRep>(rep);
                dst.numSamples = e1.numSamples + e2.numSamples;
            }
        };

        // Direct Boundary
        /**
         * Per-edge data for the secondary (direct/indirect) boundary terms.
         * Combine() concatenates every tensor along dimension 0 and sums numTot.
         */
        struct SecondaryEdgeInfo
        {
            // Total number of edges stored.
            int numTot;
            // Id-based description of the same edges; always allocated by the constructor.
            std::shared_ptr<EdgeMaterialRep> materialExp;
            Tensorb isBoundary;
            // NOTE(review): p0/e1/p2 look like an edge origin, edge vector and a further
            // point, n0/n1 like two normals -- confirm against ConstructSecEdgeList.
            Tensorf p0;
            Tensorf e1;
            Tensorf n0;
            Tensorf n1;
            Tensorf p2;

            // Creates an empty edge list holding a fresh, empty EdgeMaterialRep.
            SecondaryEdgeInfo(): numTot(0), materialExp(std::make_shared<EdgeMaterialRep>()) {}
            // Defined outside this header.
            SecondaryEdgeInfo GetMaskedCopy(const IndexMask& mask) const;
            // Defined outside this header.
            SecondaryEdgeInfo GetIndexedCopy(const Tensori& index) const;

            /**
             * Stores s1 followed by s2 in dst.
             * If one input is empty the other one's tensors are assigned directly
             * (s2 when both are empty); materialExp is merged into a new object.
             */
            static void Combine(const SecondaryEdgeInfo& s1, const SecondaryEdgeInfo& s2, SecondaryEdgeInfo& dst) {
                const bool firstEmpty = (s1.numTot == 0);
                const bool secondEmpty = (s2.numTot == 0);

                if (!firstEmpty && !secondEmpty) {
                    dst.isBoundary = Concat(s1.isBoundary, s2.isBoundary, 0);
                    dst.p0 = Concat(s1.p0, s2.p0, 0);
                    dst.e1 = Concat(s1.e1, s2.e1, 0);
                    dst.n0 = Concat(s1.n0, s2.n0, 0);
                    dst.n1 = Concat(s1.n1, s2.n1, 0);
                    dst.p2 = Concat(s1.p2, s2.p2, 0);
                } else {
                    // Exactly one source contributes; an empty s1 is checked first.
                    const SecondaryEdgeInfo& only = firstEmpty ? s2 : s1;
                    dst.isBoundary = only.isBoundary;
                    dst.p0 = only.p0;
                    dst.e1 = only.e1;
                    dst.n0 = only.n0;
                    dst.n1 = only.n1;
                    dst.p2 = only.p2;
                }

                EdgeMaterialRep mergedMaterial;
                EdgeMaterialRep::Combine(*s1.materialExp, *s2.materialExp, mergedMaterial);
                dst.materialExp = std::make_shared<EdgeMaterialRep>(mergedMaterial);

                dst.numTot = s1.numTot + s2.numTot;
            }
        };

        struct BoundarySegSampleSecondary
        {
            std::shared_ptr<EdgeSampleMaterialRep> materialRep;
            Tensorf p0;
            Tensorf edge;
            Tensorf edge2;
            Tensorf n0;
            Tensorf n1;
            Tensorf pdf;
            int numSamples;

            BoundarySegSampleSecondary(): numSamples(0), materialRep(std::make_shared<EdgeSampleMaterialRep>()) {}
            BoundarySegSampleSecondary(int _numSamples): numSamples(_numSamples), materialRep(std::make_shared<EdgeSampleMaterialRep>()) {
                p0 = Zeros(Shape({_numSamples}, VecType::Vec3));
                edge = Zeros(Shape({_numSamples}, VecType::Vec3));
                edge2 = Zeros(Shape({_numSamples}, VecType::Vec3));
                n0 = Zeros(Shape({_numSamples}, VecType::Vec3));
                n1 = Zeros(Shape({_numSamples}, VecType::Vec3));
                pdf = Ones(Shape({_numSamples}, VecType::Scalar1));
            }
            BoundarySegSampleSecondary GetMaskedCopy(const IndexMask& mask) const;
            BoundarySegSampleSecondary GetIndexedCopy(const Tensori& index) const;

            static void Combine(const BoundarySegSampleSecondary& s1, const BoundarySegSampleSecondary& s2, 
                                    BoundarySegSampleSecondary& dst) {
                if (s1.numSamples == 0) {
                    dst.p0 = s2.p0;
                    dst.edge = s2.edge;
                    dst.edge2 = s2.edge2;
                    dst.n0 = s2.n0;
                    dst.n1 = s2.n1;
                    dst.pdf = s2.pdf;
                } else if (s2.numSamples == 0) {
                    dst.p0 = s1.p0;
                    dst.edge = s1.edge;
                    dst.edge2 = s1.edge2;
                    dst.n0 = s1.n0;
                    dst.n1 = s1.n1;
                    dst.pdf = s1.pdf;
                } else {
                    dst.p0 = Concat(s1.p0, s2.p0, 0);
                    dst.edge = Concat(s1.edge, s2.edge, 0);
                    dst.edge2 = Concat(s1.edge2, s2.edge2, 0);
                    dst.n0 = Concat(s1.n0, s2.n0, 0);
                    dst.n1 = Concat(s1.n1, s2.n1, 0);
                    dst.pdf = Concat(s1.pdf, s2.pdf, 0);
                }              

                EdgeSampleMaterialRep rep;
                EdgeSampleMaterialRep::Combine(*s1.materialRep, *s2.materialRep, rep);
                dst.materialRep = std::make_shared<EdgeSampleMaterialRep>(rep);

                dst.numSamples = s1.numSamples + s2.numSamples;
            }
        };
        
        struct BoundarySegSampleDirect : BoundarySegSampleSecondary
        {
            Tensorf p2;
            Tensorf n;
            IndexMask maskValid;

            BoundarySegSampleDirect getValidCopy() const
            {
                BoundarySegSampleDirect ret;
                ret.p0 = Mask(p0, maskValid, 0);
                ret.edge = Mask(edge, maskValid, 0);
                ret.edge2 = Mask(edge2, maskValid, 0);
                ret.p2 = Mask(p2, maskValid, 0);
                ret.n = Mask(n, maskValid, 0);
                ret.pdf = Mask(pdf, maskValid, 0);
                ret.maskValid = IndexMask(Ones(maskValid.sum));
                return ret;
            }
        };
        // Free functions for the secondary/direct boundary term. Only declarations are
        // visible in this header; non-const reference parameters are outputs. The meaning
        // of the int / Tensori return values is not stated here -- confirm against the
        // implementation.
        // Takes the scene read-only; `list` is the output.
        int ConstructSecEdgeList(const Scene& scene, SecondaryEdgeInfo& list);
        // Takes an edge list and the tensor `rnd1`; `samples` (spatial form) is the output.
        Tensori SampleFromSecEdges(const SecondaryEdgeInfo& list, const Tensorf& rnd1, BoundarySegSampleSecondary& samples);
        // Same inputs, but outputs are `samples` in id/alpha form and `pdf`.
        Tensori SampleFromSecEdgesMaterial(const SecondaryEdgeInfo& list, const Tensorf& rnd1, EdgeSampleMaterialRep& samples, Tensorf& pdf);
        // Overload that additionally takes `prePdf` as a read-only input
        // (NOTE(review): presumably a pdf from an earlier sampling step -- confirm).
        Tensori SampleFromSecEdgesMaterial(const SecondaryEdgeInfo& list, const Tensorf& rnd1, const Tensorf& prePdf, EdgeSampleMaterialRep& samples, Tensorf& pdf);
        // `samples` is the output; `rnd_b`, `pdf_b` and `guiding_mode` are inputs.
        int SampleBoundarySegmentDirect(const Scene& scene, const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, const Tensorf& pdf_b, BoundarySegSampleDirect& samples, bool guiding_mode);
        // `bss` and `boundaryTerm` are both non-const references, so either may be modified.
        int EvalBoundarySegmentDirect(const Camera& camera, const Scene& scene, int mSpp, int mMaxBounce, BoundarySegSampleDirect& bss, Tensorf& boundaryTerm, bool guiding_mode);

        // Takes samples in id/alpha form (`materialRep`); `spatialRep` is the output.
        void ConvertEdgeSampleMaterialToSpatial(const Scene& scene, const EdgeSampleMaterialRep& materialRep, BoundarySegSampleSecondary& spatialRep);

        // Indirect Boundary
        struct BoundarySegSampleIndirect : BoundarySegSampleSecondary
        {
            Tensorf dir;
            IndexMask maskValid;

            BoundarySegSampleIndirect getValidCopy() const
            {
                BoundarySegSampleIndirect ret;
                ret.p0 = Mask(p0, maskValid, 0);
                ret.edge = Mask(edge, maskValid, 0);
                ret.edge2 = Mask(edge2, maskValid, 0);
                ret.dir = Mask(dir, maskValid, 0);
                ret.pdf = Mask(pdf, maskValid, 0);
                ret.maskValid = IndexMask(Ones(maskValid.sum));
                return ret;
            }
        };
        // Declaration only; defined outside this header. `samples` is the output
        // (non-const reference). The meaning of the int return value is not stated
        // here -- confirm against the implementation.
        int SampleBoundarySegmentIndirect(const Scene& scene, const SecondaryEdgeInfo& secEdges, int numSamples, BoundarySegSampleIndirect& samples);

        // New primary edge evaluation
        // Edge list used by the primary-edge functions declared below.
        // The field names match those of SecondaryEdgeInfo, but the per-edge data
        // is held as Expr rather than as Tensor types, and there is no material
        // representation. The struct declares no constructor, so numTot is not
        // initialised here -- the code that fills the struct has to set it.
        struct PrimaryEdgeInfo2
        {
            // Total number of edges (no default value).
            int numTot;
            Expr isBoundary;
            // NOTE(review): presumably the same meaning as the like-named
            // SecondaryEdgeInfo fields -- confirm against ConstructPrimEdgeList.
            Expr p0;
            Expr e1;
            Expr n0;
            Expr n1;
            Expr p2;
        };
        struct BoundarySegSamplePrimary
        {
            Tensorf p0;
            Tensorf edge;
            Tensorf edge2;
            Tensorf pdf;
            IndexMask maskValid;

            BoundarySegSamplePrimary getValidCopy() const
            {
                BoundarySegSamplePrimary ret;
                ret.p0 = Mask(p0, maskValid, 0);
                ret.edge = Mask(edge, maskValid, 0);
                ret.edge2 = Mask(edge2, maskValid, 0);
                ret.pdf = Mask(pdf, maskValid, 0);
                ret.maskValid = IndexMask(Ones(maskValid.sum));
                return ret;
            }
        };
        // Free functions for the primary boundary term. Only declarations are visible
        // in this header; non-const reference parameters are outputs. The meaning of
        // the int / Tensori return values is not stated here -- confirm against the
        // implementation.
        // Takes the scene and camera read-only; `list` is the output.
        int ConstructPrimEdgeList(const Scene& scene, const Camera& camera, PrimaryEdgeInfo2& list);
        // Takes an edge list and a sample count; `samples` is the output.
        Tensori SampleFromPrimEdges(const PrimaryEdgeInfo2& list, int numSamples, BoundarySegSamplePrimary& samples);
        // Like SampleFromPrimEdges but also receives the scene; `samples` is the output.
        int SampleBoundarySegmentPrimary(const Scene& scene, const PrimaryEdgeInfo2& primEdges, int numSamples, BoundarySegSamplePrimary& samples);

        // Pixel boundary
        struct BoundarySegSamplePixel : BoundarySegSamplePrimary
        {
            Tensori rayIdx;
            Tensori pixelIdx;

            BoundarySegSamplePixel getValidCopy() const
            {
                BoundarySegSamplePixel ret;
                ret.p0 = Mask(p0, maskValid, 0);
                ret.edge = Mask(edge, maskValid, 0);
                ret.edge2 = Mask(edge2, maskValid, 0);
                ret.pdf = Mask(pdf, maskValid, 0);
                ret.rayIdx = Mask(rayIdx, maskValid, 0);
                ret.pixelIdx = Mask(pixelIdx, maskValid, 0);
                ret.maskValid = IndexMask(Ones(maskValid.sum));
                return ret;
            }
        };

        // Declaration only; defined outside this header. Takes the camera read-only
        // plus two sample counts (`spp`, `antitheticSpp`); `samples` is the output
        // (non-const reference).
        void SampleBoundarySegmentPixel(const Camera& camera, int spp, int antitheticSpp, BoundarySegSamplePixel& samples);
    }
}
