/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once

#include "BoundaryDirect2.h"

#include "./Algorithm1.h"
#include "../Boundary.h"
#include "cukd/knnfinder.h"

namespace EDX
{
    namespace TensorRay
    {
        /**
         * Streams one weighted candidate into reservoir `r`.
         *
         * @param r      Reservoir to update in place (its `m`, `wAverage` and kept sample `x` are modified).
         * @param xI     Candidate sample offered to the reservoir.
         * @param wI     Weight of the candidate.
         * @param m1     Sample count the candidate stands for; entries with `m1 <= 0` never replace the kept sample.
         * @param sample Per-entry random numbers compared against the acceptance ratio.
         * @return       Integer mask of the entries whose kept sample was replaced by `xI`.
         */
        Tensori EdgeReservoir::UpdateReservoir(EdgeReservoir& r, const LightSample& xI, const Expr& wI,
                const Expr& m1, const Expr& sample) {
            // Fold the candidate's count into the total first; everything below reads the updated total.
            r.m = r.m + m1;

            // 1 / m where the reservoir has seen anything, 0 elsewhere (avoids dividing by zero).
            Expr hasAny = (r.m > Scalar(0));
            Expr invTotal = Where(hasAny, Scalar(1.0f) / r.m, Scalar(0.0f));

            // Count-weighted mean of the previous average weight and the incoming weight.
            Expr keptShare = (r.m - m1) * invTotal;
            Expr incomingShare = m1 * invTotal;
            r.wAverage = keptShare * r.wAverage + incomingShare * wI;

            // Acceptance ratio: (wI * m1) / (m * wAverage), forced to 0 where the average is not positive.
            Expr invAverage = Where(r.wAverage > Scalar(0.0f), Scalar(1.0f) / r.wAverage, Scalar(0.0f));
            Expr acceptRatio = wI * m1 * invTotal * invAverage;
            Expr accepted = (hasAny * (sample < acceptRatio) * (m1 > Scalar(0)));

            r.x->UpdateMasked(xI, accepted);

            return Tensori(accepted);
        }

        /**
         * Merges reservoir `r1` into `r` by offering r1's kept sample as a single candidate.
         *
         * The candidate weight is `r1.W * r1.x->pHat` and it counts for `r1.m` samples.
         *
         * @return Integer mask of the entries where r1's sample replaced the one kept in `r`.
         */
        Tensori EdgeReservoir::UpdateReservoir(EdgeReservoir& r, const EdgeReservoir& r1, const Expr& sample){
            const LightSample& candidate = *r1.x;
            Expr candidateWeight = r1.W * r1.x->pHat;
            return UpdateReservoir(r, candidate, candidateWeight, r1.m, sample);
        }

        /**
         * Asserts that two secondary-edge lists have the same length and that every
         * entry carries the same global, shape and internal ids in its material
         * representation. All checks are `assert`s, so nothing is verified when
         * asserts are compiled out.
         */
        void CheckEdgeList(const SecondaryEdgeInfo& list1, const SecondaryEdgeInfo& list2) {
            assert(list1.numTot == list2.numTot);

            int i = 0;
            while (i < list1.numTot) {
                assert(list1.materialExp->globalId.Get(i) == list2.materialExp->globalId.Get(i));
                assert(list1.materialExp->shapeId.Get(i) == list2.materialExp->shapeId.Get(i));
                assert(list1.materialExp->internalId.Get(i) == list2.materialExp->internalId.Get(i));
                ++i;
            }
        }

        void CheckEdgeList(const EdgeIndexInfo& list1, const EdgeIndexInfo& list2) {
            assert(list1.numEdges == list2.numEdges);
            for (int i = 0; i < list1.numEdges; ++i) {
                assert(list1.indexTri0.Get(i) == list1.indexTri0.Get(i));
                assert(list1.indexTri1.Get(i) == list1.indexTri1.Get(i));
                assert(list1.indexVert0.Get(i) == list2.indexVert0.Get(i));
                assert(list1.indexVert1.Get(i) == list2.indexVert1.Get(i));
                assert(list1.indexVert2.Get(i) == list2.indexVert2.Get(i));
            }
        }

        /**
         * Runs the direct boundary estimator in batches and feeds the result to the
         * gradient machinery.
         *
         * For each pass it draws `imageShape[0] * mSppBatch` samples, builds edge/light
         * samples with SampleBoundarySegmentReSTIR, evaluates them with
         * EvalBoundarySegment, and hands the resulting term to
         * mGradHandler.AccumulateDeriv (and to Backward when mDLoss is non-empty).
         *
         * Returns early, doing nothing, when mSpp == 0, when ConstructSecEdgeList
         * returns 0, or when either sampling or evaluation returns 0 in some pass.
         *
         * @param scene Scene to integrate over; the first sensor (scene.mSensors[0]) is used.
         * @param image Not referenced anywhere in this body.
         */
        void DirectBoundaryIntegrator2::Integrate(const Scene& scene, Tensorf& image) const
        {
#if USE_PROFILING
            nvtxRangePushA(__FUNCTION__);
#endif
            // Every return path below pops the profiling range pushed above.
            if (mSpp == 0)
            {
#if USE_PROFILING
                nvtxRangePop();
#endif
                return;
            }
            const Camera& camera = *scene.mSensors[0];
            
            SecondaryEdgeInfo secEdges;
            if (ConstructSecEdgeList(scene, secEdges) == 0) 
            {
#if USE_PROFILING
                nvtxRangePop();
#endif
                return;
            }

            // if (storedInfo != nullptr) {
            //     CheckEdgeList(*storedInfo, scene.mPrims[2]->mpMesh->mEdgeInfo);
            // }
            // storedInfo = std::make_shared<EdgeIndexInfo>(scene.mPrims[2]->mpMesh->mEdgeInfo);

            // TODO: Guiding Parameters
            // Edge-length distribution; in this function only its CDF (cut_cdf) is consumed,
            // and only by the guiding setup below.
            Tensorf edgeLength = VectorLength(secEdges.e1);
            Distribution1D secEdgeDistrb(edgeLength);
            int edge_size = edgeLength.LinearSize();
            // NOTE(review): idx is never read in this function.
            Tensori idx = Tensori::ArrayRange(0, edge_size-1, 1);
            
            if (mVerbose)
                std::cout << "Total edge_size: " << edge_size << std::endl;

            AdaptiveQuadratureDistribution m_aq;
            Tensorf cut_cdf = secEdgeDistrb.mCDF;

            // Guiding structure is built only when g_direct is set; m_aq is not used
            // again after setup within this function.
            if (g_direct) 
            {
                Timer timer_guide;
                timer_guide.Start();
                m_aq.setup(camera, scene, secEdges, cut_cdf, g_options);
                if (mVerbose)
                    std::cout << string_format("[DirectBoundary Guiding] Total Elapsed time = %f seconds", timer_guide.GetElapsedTime()) << std::endl;
            }
                
            Timer timer;
            timer.Start();
            Shape imageShape = Shape({ camera.mResX * camera.mResY }, VecType::Vec3);
            // NOTE(review): if mSpp and mSppBatch are both integers, the division truncates
            // before std::ceil sees it, so a partial last batch would be dropped — confirm
            // the member types.
            int npass = std::ceil(mSpp / mSppBatch);
            for (int ipass = 0; ipass < npass; ipass++) 
            {
                int numSecondarySamples = imageShape[0] * mSppBatch;
                // NOTE(review): bss is declared but never read in this function.
                BoundarySegSampleDirect bss;
                // Step 1: Sample point on the edge & emitter
                Tensorf rnd_b = Tensorf::RandomFloat(Shape({ numSecondarySamples }, VecType::Vec3)); // no guiding
                EdgeDirectSample edgeDirectSamples;
                if (SampleBoundarySegmentReSTIR(scene, secEdges, numSecondarySamples, rnd_b, edgeDirectSamples) == 0)
                {
#if USE_PROFILING
                    nvtxRangePop();
#endif
                    return;
                }
                // 
                // std::cout << numSecondarySamples << std::endl;
                // std::cout << bss.pdf.LinearSize() << std::endl;
                // Step 2: evaluate the samples into a per-pixel term. The total mSpp (not the
                // batch size) is what gets passed to the evaluator.
                Tensorf boundaryTerm = Zeros(Shape({ camera.mResX * camera.mResY }, VecType::Vec3));
                if (EvalBoundarySegment(scene, edgeDirectSamples, boundaryTerm, mSpp) == 0) 
                {
#if USE_PROFILING
                    nvtxRangePop();
#endif
                    return;
                }
                // Subtract a detached copy of the term from itself; presumably this zeroes the
                // value while keeping the derivative — confirm against Detach semantics.
                boundaryTerm = boundaryTerm - Detach(boundaryTerm);
                mGradHandler.AccumulateDeriv(boundaryTerm);
                if (!mDLoss.Empty()) 
                {
                    boundaryTerm.Backward(mDLoss);
                    AccumulateGradsAndReleaseGraph();
                }

                // Progress line ends in '\r' with no newline, so successive passes overwrite it.
                if (mVerbose)
                    std::cout << string_format("[DirectBoundary2] #Pass %d / %d, %d kernels launched\r", ipass + 1, npass, KernelLaunchCounter::GetHandle());
                KernelLaunchCounter::Reset();
            }
            // NOTE(review): mSppBatch is formatted with %f here; confirm that matches its type.
            if (mVerbose)
                std::cout << string_format("[DirectBoundary2] Total Elapsed time = %f (%f samples/pass, %d pass)", timer.GetElapsedTime(), mSppBatch, npass) << std::endl;
#if USE_PROFILING
            nvtxRangePop();
#endif
        }

        /**
         * Samples edge points and, for each of them, a light sample chosen by
         * reservoir resampling, optionally reusing reservoirs kept from earlier calls.
         *
         * Stages, in order:
         *   1. Optionally derive a per-edge weight (prePdf) from the prefiltered stored reservoirs.
         *   2. Sample edge points with SampleFromSecEdgesMaterial and convert them to spatial form.
         *   3. Stream M freshly sampled light candidates per edge point into a new reservoir.
         *   4. When temporal reuse is on and stored reservoirs exist, merge k1 nearest stored
         *      reservoirs into the new one and recompute its weight W.
         *   5. When temporal reuse is on, combine the new reservoir into reservoirsOnFile.
         *   6. Write the chosen light samples and their pdf (1 / W) into `samples`.
         *
         * @param scene      Scene being rendered.
         * @param secEdges   Secondary edge list to sample from.
         * @param numSamples Not read before it is shadowed by a local of the same name inside
         *                   the temporal-reuse branch.
         * @param rnd_b      Random numbers; only the X component is consumed here.
         * @param samples    Output: edge points, their pdf, boundary flags, light samples and light pdf.
         * @return           Always 1.
         */
        int DirectBoundaryIntegrator2::SampleBoundarySegmentReSTIR(const Scene& scene, 
                const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, 
                EdgeDirectSample& samples) const {
            auto samples_x = X(rnd_b);
            // NOTE(review): samples_y and samples_z are never read in this function; light
            // samples draw their own random numbers below.
            auto samples_y = Y(rnd_b);
            auto samples_z = Z(rnd_b);

            // Sampling an edge point

            // Per-edge weight handed to the edge sampler; all ones unless guiding data is available.
            Tensorf prePdf = Ones(Shape({secEdges.numTot}, VecType::Scalar1));
            std::shared_ptr<EdgeReservoir> prefiltReservoirs;
            if (reservoirs != nullptr && reservoirs->numOfSamples > 0) {
                // NOTE(review): std::move on a function-call result is redundant.
                prefiltReservoirs = std::move(PrefilerReservoir(scene, *scene.mSensors[0], reservoirs, true));
                if (prefiltReservoirs->numOfSamples > 0 && haveTemporalReuse && boundaryGuiding) {
                    // Re-evaluate pHat of every stored reservoir, then reduce per edge
                    // (keyed by globalId) into a count and a sum.
                    Tensori m = Ones(Shape({prefiltReservoirs->numOfSamples}, VecType::Scalar1));
                    BoundarySegSampleSecondary edgePoints;
                    ConvertEdgeSampleMaterialToSpatial(scene, 
                                *prefiltReservoirs->edgePoints->materialRep,
                                edgePoints);
                    SpatialVertices lightPoints;
                    ConvertToSpatialVertices(scene, *prefiltReservoirs->x->vertices, lightPoints);
                    Tensori validIndex;
                    Tensorf prePHat = EvalPHat(scene, edgePoints, 
                            lightPoints, prefiltReservoirs->isBoundary, validIndex);
                    // assumes IndexedWrite sums entries that land on the same index — TODO confirm
                    Tensori numOfSamples = IndexedWrite(m, 
                        prefiltReservoirs->edgePoints->materialRep->edgeMaterialInfo->globalId, prePdf.GetShape(), 0);
                    Tensorf wSum = IndexedWrite(prePHat, 
                        prefiltReservoirs->edgePoints->materialRep->edgeMaterialInfo->globalId, prePdf.GetShape(), 0);
                    // Mean pHat per edge; edges with no stored reservoir get weight 0.
                    prePdf = Where(numOfSamples > Scalar(0), wSum / numOfSamples, Scalar(0.0f));
                }
            }

            EdgeSampleMaterialRep edgeMaterialRep;
            BoundarySegSampleSecondary edgeSpatialRep;
            Tensorf edgeSamplingPdf;
            Tensori edgeIdx = SampleFromSecEdgesMaterial(secEdges, samples_x, prePdf, edgeMaterialRep, edgeSamplingPdf);
            ConvertEdgeSampleMaterialToSpatial(scene, edgeMaterialRep, edgeSpatialRep);
            edgeSpatialRep.pdf = edgeSamplingPdf;
            samples.edgePoints = std::make_shared<BoundarySegSampleSecondary>(edgeSpatialRep);
            samples.edgePointsPdf = edgeSamplingPdf;
            samples.isBoundary = IndexedRead(secEdges.isBoundary, edgeIdx, 0);

            // NOTE(review): mm0 is only referenced by commented-out code further down.
            int mm0 = edgeMaterialRep.numSamples;
            // Reservoir for this call: one entry per sampled edge point.
            EdgeReservoir curIterationReservoir(edgeMaterialRep.numSamples);
            curIterationReservoir.edgePoints = std::make_shared<BoundarySegSampleSecondary>(*samples.edgePoints);
            curIterationReservoir.isBoundary = samples.isBoundary;
            // Sampling lights
            // Each of the M rounds draws one light candidate per edge point and streams it into
            // the reservoir with weight pHat / pdf and a count of 1.
            for (int i = 0; i < M; ++i) {
                LightSample curLightSample(edgeMaterialRep.numSamples);
                Tensorf rndLight = Tensorf::RandomFloat(Shape({ edgeMaterialRep.numSamples }, VecType::Vec2));

                SampleEmitterDirect(scene, rndLight, *curLightSample.vertices);

                SpatialVertices curSpatialLightSamples;
                ConvertToSpatialVertices(scene, *curLightSample.vertices, curSpatialLightSamples);

                Tensori curValidIndex;
                curLightSample.pHat = EvalPHat(scene, *samples.edgePoints, curSpatialLightSamples, samples.isBoundary, curValidIndex);
                Tensorf rndReservoir = Tensorf::RandomFloat(
                    Shape({ curLightSample.numOfSamples }, VecType::Scalar1));
                EdgeReservoir::UpdateReservoir(curIterationReservoir, curLightSample, 
                    curLightSample.pHat / curLightSample.vertices->pdf, Scalar(1), rndReservoir);
            }
            // Weight of the kept sample: wAverage / pHat, zero where pHat is not positive.
            // Detached, so it carries no derivative.
            curIterationReservoir.W = Detach(Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
                                            curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
                                            Scalar(0.0f)));

            if (haveTemporalReuse && prefiltReservoirs != nullptr && prefiltReservoirs->numOfSamples > 0) {
                // m0 is read again after the merge loop, for the reservoir's own term in pSum.
                Expr m0 = curIterationReservoir.m;
                // pickedIndex records which source supplied the kept sample: 0 = this call's own
                // candidates, j + 1 = the j-th merged neighbour.
                Tensori pickedIndex = Zeros(Shape({curIterationReservoir.numOfSamples}, VecType::Scalar1));
                // NOTE(review): w0 is never read.
                Expr w0 = curIterationReservoir.wAverage;
                // NOTE(review): this local shadows the numSamples parameter.
                int numSamples = curIterationReservoir.numOfSamples;
                // Cap the stored counts at M * historyLength.
                prefiltReservoirs->m = Where(prefiltReservoirs->m > Scalar(M * historyLength), Scalar(M * historyLength), prefiltReservoirs->m);
                Tensori localShiftMapping = Zeros(Shape({ numSamples * k }, VecType::Scalar1));
                // Only stored reservoirs with a positive count and positive pHat are search candidates.
                IndexMask validReservoirIndex = ((prefiltReservoirs->m > Scalar(0)) && (prefiltReservoirs->x->pHat > Scalar(0.0f)));
                {
                    // For every current edge point, find k neighbours among the valid stored edge
                    // points. Results are indices into the masked (valid-only) set.
                    BoundarySegSampleSecondary prevEdgePoints;
                    ConvertEdgeSampleMaterialToSpatial(scene, 
                        *prefiltReservoirs->edgePoints->GetMaskedCopy(validReservoirIndex).materialRep, 
                        prevEdgePoints);
                    cukd::FindKNN(prevEdgePoints.p0.Data(), 
                        validReservoirIndex.sum, 
                        curIterationReservoir.edgePoints->p0.Data(), 
                        numSamples, 
                        localShiftMapping.Data(), 
                        k);
                    localShiftMapping.CopyToHost();
                }

                // Randomise which neighbour slots are used; the first k1 entries are consumed.
                shuffle(reservoirToUse.begin(), reservoirToUse.end(), randomGenerator);
                // Pass 1: merge each chosen neighbour reservoir into the current one.
                for (int j = 0; j < k1; ++j) {
                    int i = reservoirToUse[j];
                    // Slice [numSamples * i, numSamples * (i + 1)) is read as "neighbour slot i for
                    // every sample" — assumes FindKNN writes its output in that layout; TODO confirm.
                    Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                        Tensori::ArrayRange(numSamples * i, numSamples * (i + 1), 1, false), 0);
                    // Map indices within the valid subset back to indices into prefiltReservoirs.
                    Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, 
                                                            curLocalShiftMapping, 0);
                    EdgeReservoir previousReservoirCopy = prefiltReservoirs->GetIndexedCopy(globalShiftMapping, 
                                                            numSamples);

                    SpatialVertices prevSpatialLightSamples;
                    ConvertToSpatialVertices(scene, *previousReservoirCopy.x->vertices, prevSpatialLightSamples);

                    // NOTE(review): previousEdgePoints is filled here but not read in this loop.
                    BoundarySegSampleSecondary previousEdgePoints;
                    ConvertEdgeSampleMaterialToSpatial(scene, *(previousReservoirCopy.edgePoints->materialRep), 
                            previousEdgePoints);

                    // Re-evaluate the neighbour's light sample at the *current* edge point.
                    Tensori previousValidIndex;
                    previousReservoirCopy.x->pHat = EvalPHat(scene, *(curIterationReservoir.edgePoints), 
                            prevSpatialLightSamples, previousReservoirCopy.isBoundary, previousValidIndex);
                    Tensori valid = IndexedWrite(Scalar(1), previousValidIndex, previousReservoirCopy.x->pHat.GetShape(), 0);

                    // Similarity test between the current and the neighbour edge point, built from
                    // cross(n0, edge) on both sides.
                    Expr normal1 = VectorCross(curIterationReservoir.edgePoints->n0, curIterationReservoir.edgePoints->edge);
                    Expr normal2 = VectorCross(previousReservoirCopy.edgePoints->n0, previousReservoirCopy.edgePoints->edge);
                    Expr tangent1, bitangent1;
                    Expr tangent2, bitangent2;
                    CoordinateSystem(normal1, &tangent1, &bitangent1);
                    CoordinateSystem(normal2, &tangent2, &bitangent2);
                    Tensorf tangentTest = VectorDot(tangent1, tangent2);

                    // Expr normalDotProduct0Exp = VectorDot(previousReservoirCopy.edgePoints->n0, curIterationReservoir.edgePoints->n0);
                    // Tensorf normalDotProduct0 = normalDotProduct0Exp;
                    // Tensorf normalDotProduct1 = VectorDot(previousReservoirCopy.edgePoints->n1, curIterationReservoir.edgePoints->n1);
                    Tensorf dist = VectorLength(previousReservoirCopy.edgePoints->p0 - curIterationReservoir.edgePoints->p0);
                    // A neighbour is kept only if it is aligned enough, close enough, and its
                    // sample evaluated as valid at the current edge point.
                    Expr nFilter = ((tangentTest > Scalar(reservoirMergeNormalThreshold)) * 
                                    (dist < Scalar(reservoirMergeDistThreshold)) * 
                                    (valid > Scalar(0)));
                    // Rejected neighbours get zero count and zero weight, so merging them changes nothing.
                    previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
                    previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));
                    previousReservoirCopy.wAverage = Where(nFilter, previousReservoirCopy.wAverage, Scalar(0.0f));

                    Tensorf rndReservoir = Tensorf::RandomFloat(
                        Shape({ curIterationReservoir.numOfSamples }, VecType::Scalar1));

                    Tensori curMask = EdgeReservoir::UpdateReservoir(curIterationReservoir, previousReservoirCopy, rndReservoir);
                    pickedIndex = Where(curMask > Scalar(0), Scalar(j + 1), pickedIndex);
                }

                // Tensorf Z = Zeros(Shape({ numSamples }, VecType::Scalar1));
                // Tensorf _M = Zeros(Shape({ numSamples }, VecType::Scalar1));
                // Pass 2: evaluate the finally kept light sample at every contributing edge point.
                //   pSum = sum over sources of (count * pHat of the kept sample at that source)
                //   pI   = pHat of the kept sample at the source it was picked from
                Tensorf pI = Zeros(Shape({ numSamples }, VecType::Scalar1));
                Tensorf pSum = Zeros(Shape({ numSamples }, VecType::Scalar1));
                SpatialVertices curLightSampleSpatialPoints;
                ConvertToSpatialVertices(scene, *(curIterationReservoir.x->vertices), curLightSampleSpatialPoints);
                // Same neighbour selection as pass 1 (reservoirToUse is not reshuffled in between).
                for (int j = 0; j < k1; ++j) {
                    int i = reservoirToUse[j];
                    Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                        Tensori::ArrayRange(numSamples * i, numSamples * (i + 1), 1, false), 0);
                    Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, curLocalShiftMapping, 0);
                    EdgeReservoir previousReservoirCopy = prefiltReservoirs->GetIndexedCopy(globalShiftMapping, numSamples);

                    BoundarySegSampleSecondary previousEdgePoints;
                    ConvertEdgeSampleMaterialToSpatial(scene, *(previousReservoirCopy.edgePoints->materialRep), 
                            previousEdgePoints);

                    // Apply normal test
                    Expr normal1 = VectorCross(curIterationReservoir.edgePoints->n0, curIterationReservoir.edgePoints->edge);
                    Expr normal2 = VectorCross(previousReservoirCopy.edgePoints->n0, previousReservoirCopy.edgePoints->edge);
                    Expr tangent1, bitangent1;
                    Expr tangent2, bitangent2;
                    CoordinateSystem(normal1, &tangent1, &bitangent1);
                    CoordinateSystem(normal2, &tangent2, &bitangent2);
                    Tensorf tangentTest = VectorDot(tangent1, tangent2);
                    Tensorf dist = VectorLength(previousReservoirCopy.edgePoints->p0 - curIterationReservoir.edgePoints->p0);
                    // Here the kept sample is evaluated at the *neighbour's* edge point.
                    Tensori previousValidIndex;
                    Expr pHatPrevCur = EvalPHat(scene, previousEdgePoints, 
                            curLightSampleSpatialPoints, previousReservoirCopy.isBoundary, previousValidIndex);
                    Tensorf valid = IndexedWrite(Scalar(1.0f), previousValidIndex, previousReservoirCopy.x->pHat.GetShape(), 0);
                    Expr nFilter = ((tangentTest > Scalar(reservoirMergeNormalThreshold)) * 
                                    (dist < Scalar(reservoirMergeDistThreshold)) * 
                                    (valid > Scalar(0.0f)));
                    previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));

                    pSum = pSum + previousReservoirCopy.m * pHatPrevCur * valid;
                    pI = Where(pickedIndex == Scalar(j + 1), pHatPrevCur * valid, pI);
                    // Z = Z + Where(pHatPrevCur > Scalar(0.0f), previousReservoirCopy.m * previousReservoirCopy.wAverage + Scalar(RESTIR_PHAT_EPSILON), Scalar(0.0f));
                    // _M = _M + previousReservoirCopy.m * previousReservoirCopy.wAverage + Scalar(RESTIR_PHAT_EPSILON);
                }

                // Test itself
                // The current reservoir's own contribution uses its pre-merge count m0.
                Tensori curValidIndex;
                Expr pHatCurCur = EvalPHat(scene, *(curIterationReservoir.edgePoints), 
                            curLightSampleSpatialPoints, curIterationReservoir.isBoundary, curValidIndex);
                Tensorf valid = IndexedWrite(Scalar(1.0f), curValidIndex, curIterationReservoir.x->pHat.GetShape(), 0);
                pSum = pSum + m0 * pHatCurCur * valid;
                pI = Where(pickedIndex == Scalar(0), pHatCurCur * valid, pI);
                // mi = pI / pSum (0 where pSum is not positive); final W = wSum * mi / pHatCurCur.
                Tensorf mi = Where(pSum > Scalar(0.0f), pI / pSum, Scalar(0.0f));
                Tensorf wSum = curIterationReservoir.wAverage * curIterationReservoir.m;
                curIterationReservoir.W = Where(pHatCurCur > Scalar(0.0f), Scalar(1.0f) / pHatCurCur * wSum * mi, Scalar(0.0f));
            }

            if (haveTemporalReuse) {
                // NOTE(review): this outer `r` is never read — the inner block declares its own
                // `r` and that is what gets stored. Unless Combine has side effects on its
                // inputs, this call looks redundant. It also dereferences reservoirsOnFile
                // without a null check — confirm it is always initialised before this point.
                EdgeReservoir r;
                EdgeReservoir::Combine(*reservoirsOnFile, curIterationReservoir, r);
                SpatialVertices curLightSampleSpatialPoints;
                ConvertToSpatialVertices(scene, *(curIterationReservoir.x->vertices), curLightSampleSpatialPoints);
                // Re-evaluate only to obtain curValidIndex: the entries whose kept sample is valid.
                Tensori curValidIndex;
                Expr pHatCurCur = EvalPHat(scene, *(curIterationReservoir.edgePoints), 
                            curLightSampleSpatialPoints, curIterationReservoir.isBoundary, curValidIndex);
                // IndexMask mm1Mask = pHatCurCur > Scalar(SHADOW_EPSILON);
                // int mm1 = mm1Mask.sum;
                // std::cout << float(mm1) / mm0 << std::endl;
                // Only the valid entries are combined into the persistent reservoir set.
                if (curValidIndex.NumElements() > 0) {
                    EdgeReservoir r;
                    EdgeReservoir::Combine(*reservoirsOnFile, curIterationReservoir.GetIndexedCopy(curValidIndex, curValidIndex.NumElements()), r);
                    reservoirsOnFile = std::make_shared<EdgeReservoir>(r);
                } 
                else {
                    // EdgeReservoir r;
                    // auto reservoirToMerge = PrefilerReservoir(scene, *scene.mSensors[0], std::shared_ptr<EdgeReservoir>(curIterationReservoir));
                    // EdgeReservoir::Combine(*reservoirsOnFile, 
                    //                         *PrefilerReservoir(scene, *scene.mSensors[0], std::make_shared<EdgeReservoir>(curIterationReservoir)), 
                    //                         r);
                    // reservoirsOnFile = std::make_shared<EdgeReservoir>(r);
                }
            }

            // Output: the kept light sample per edge point, with pdf = 1 / W (0 where W is not positive).
            SpatialVertices spatialLightSamples;
            ConvertToSpatialVertices(scene, *curIterationReservoir.x->vertices, spatialLightSamples);
            samples.lightSamples = std::make_shared<SpatialVertices>(spatialLightSamples);
            samples.lightPdf = Where(curIterationReservoir.W > Scalar(0.0f), Scalar(1.0f) / curIterationReservoir.W,
                                            Scalar(0.0f));
            samples.lightSamples->pdf = Where(curIterationReservoir.W > Scalar(0.0f), Scalar(1.0f) / curIterationReservoir.W,
                                            Scalar(0.0f));

            return 1;
        }

        /**
         * Evaluates the resampling target value (pHat) for each (edge point, light sample) pair.
         *
         * A pair gets a non-zero value only if it passes two filters:
         *   1. a geometric test (edge-side test, light facing the edge point, minimum distance);
         *   2. a ray test: a ray from the edge point toward the light must hit within
         *      SHADOW_EPSILON of the light sample, and the ray in the opposite direction must
         *      also hit a triangle.
         * For the surviving pairs the value is the luminance of
         * |emitted radiance * baseVal * G1 + SHADOW_EPSILON|.
         *
         * @param scene        Scene used for ray intersection and emitter evaluation.
         * @param edgePoints   Edge points (p0, n0, n1, edge, edge2, pdf, numSamples are read).
         * @param lightSamples Light samples paired one-to-one with the edge points.
         * @param isBoundary   Per-sample flag selecting which edge-side test applies.
         * @param validIndex   Output: indices (into the input arrays) of the pairs that passed
         *                     both filters. NOT written on either early return.
         * @return             One value per input pair, zero for pairs that were filtered out.
         */
        Tensorf DirectBoundaryIntegrator2::EvalPHat(const Scene& scene, const BoundarySegSampleSecondary& edgePoints,
                            const SpatialVertices& lightSamples, const Tensorb& isBoundary, Tensori& validIndex) const {
            // secEdgeSamples.p2 = lightSamples.p;
            // secEdgeSamples.n = lightSamples.n;
            Tensorf pHat = Zeros(Shape({ edgePoints.numSamples }, VecType::Scalar1));
            // Direction from the (detached) edge point to the light sample.
            auto e = lightSamples.position - Detach(edgePoints.p0);
            auto distSqr = VectorSquaredLength(e);
            auto dist = Sqrt(distSqr);
            auto eNormalized = e / dist;
            // Cosine at the light, and cosines against the two face normals stored with the edge.
            auto cosTheta = VectorDot(lightSamples.normal, -eNormalized);
            auto cosine0 = VectorDot(eNormalized, edgePoints.n0);
            auto cosine1 = VectorDot(eNormalized, edgePoints.n1);
            // valid0: direction is not (nearly) perpendicular to n0.
            // valid1: the two cosines have opposite signs, each beyond EDGE_EPSILON.
            auto valid0 = Abs(cosine0) > Scalar(EDGE_EPSILON);
            auto valid1 = (cosine0 > Scalar(EDGE_EPSILON) && cosine1 < Scalar(-EDGE_EPSILON)) || (cosine0 < Scalar(-EDGE_EPSILON) && cosine1 > Scalar(EDGE_EPSILON));
            // Entries flagged isBoundary use valid0; all others use valid1.
            auto rightSide = (isBoundary && valid0) || (~isBoundary && valid1);
            auto valid2 = cosTheta > Scalar(EPSILON);
            auto valid3 = dist > Scalar(SHADOW_EPSILON);

            Tensorf G1 = cosTheta / (distSqr);
            // secEdgeSamples.pdf = secEdgeSamples.pdf * lightSamples.pdf * distSqr / cosTheta * pdf_b;
            // secEdgeSamples.maskValid = IndexMask(rightSide && valid2 && valid3);
            // Filter 1: geometric validity. Nothing survives -> all-zero result, validIndex untouched.
            IndexMask validMask1(rightSide && valid2 && valid3);
            if (validMask1.sum == 0) {
                return pHat;
            }
            BoundarySegSampleSecondary validEdgePoints = edgePoints.GetMaskedCopy(validMask1);
            SpatialVertices lightSamples1 = lightSamples.GetMaskedCopy(validMask1);

            // if (secEdgeSamples.maskValid.sum > 0)
            //     secEdgeSamples = secEdgeSamples.getValidCopy();
            {
                IndexMask validMask2;
                BoundarySegSampleSecondary validEdgePoints2;
                SpatialVertices lightSamples2;
                // `ray` leaves the edge point heading away from the light; `ray0` (below) heads toward it.
                Tensorf rayDir = Detach(VectorNormalize(validEdgePoints.p0 - lightSamples1.position));
                Ray ray(validEdgePoints.p0, rayDir);
                Intersection its;
                // NOTE(review): xDotN is only assigned in commented-out code below.
                Expr emittedRadiance, baseVal, xDotN;
                {
                    Intersection its0;
                    Ray ray0(validEdgePoints.p0, -rayDir);
                    scene.Intersect(ray0, its0);
                    scene.Intersect(ray, its);
                    // Filter 2: the light-ward ray must land on the light sample itself, and both
                    // rays must hit a triangle.
                    auto hitP = ray0.mOrg + its0.mTHit * ray0.mDir;
                    Tensorb samePoint = VectorLength(hitP - lightSamples1.position) < Scalar(SHADOW_EPSILON);
                    validMask2 = IndexMask(samePoint && its0.mTriangleId != Scalar(-1)
                        && its.mTriangleId != Scalar(-1));
                    if (validMask2.sum == 0)
                    {
                        return pHat;
                    }

                    ray = ray.GetMaskedCopy(validMask2);
                    its0 = its0.GetMaskedCopy(validMask2);
                    its = its.GetMaskedCopy(validMask2);
                    // NOTE(review): this local shadows the validIndex out-parameter and is never
                    // used; the out-parameter itself is assigned after this inner scope closes.
                    Tensori validIndex;
                    validEdgePoints2 = validEdgePoints.GetMaskedCopy(validMask2);
                    lightSamples2 = lightSamples1.GetMaskedCopy(validMask2);

                    // Emitted radiance: evaluated only where the light-ward hit carries an emitter id,
                    // using the light at scene.mAreaLightIndex; zero elsewhere.
                    emittedRadiance = Zeros(Shape({ ray.mNumRays }, VecType::Vec3));
                    IndexMask mask_light = its0.mEmitterId != Scalar(-1);
                    if (mask_light.sum > 0) 
                    {
                        Intersection its_light = its0.GetMaskedCopy(mask_light);
                        auto dir = Mask(ray.mDir, mask_light, 0);
                        scene.PostIntersect(its_light);
                        auto val = Detach(scene.mLights[scene.mAreaLightIndex]->Eval(its_light, dir));
                        emittedRadiance = emittedRadiance + IndexedWrite(val, mask_light.index, emittedRadiance->GetShape(), 0);
                    }

                    scene.PostIntersect(its);
                    // NOTE(review): d_position is never read.
                    Tensorf d_position = its.mPosition;
                    // Geometric factor baseVal = (itsT / dist) * (sinphi / sinphi2) * cos2, zeroed where
                    // either sine is below EPSILON, then signed by whether sign0 and sign1 agree.
                    // `dist` and `e` here shadow the outer variables of the same name.
                    auto dist = VectorLength(lightSamples2.position - its.mPosition);
                    auto cos2 = Abs(VectorDot(lightSamples2.normal, ray.mDir));
                    auto e = VectorCross(validEdgePoints2.edge, -ray.mDir);
                    auto sinphi = VectorLength(e);
                    auto proj = VectorNormalize(VectorCross(e, lightSamples2.normal));
                    auto sinphi2 = VectorLength(VectorCross(-ray.mDir, proj));
                    auto itsT = VectorLength(its.mPosition - validEdgePoints2.p0);
                    auto n = Detach(VectorNormalize(VectorCross(lightSamples2.normal, proj)));
                    auto sign0 = VectorDot(e, validEdgePoints2.edge2) > Scalar(0.0f);
                    auto sign1 = VectorDot(e, n) > Scalar(0.0f);
                    baseVal = Detach((itsT / dist) * (sinphi / sinphi2) * cos2);
                    baseVal = baseVal * (sinphi > Scalar(EPSILON)) * (sinphi2 > Scalar(EPSILON));
                    baseVal = baseVal * Where(sign0 == sign1, Ones(validEdgePoints2.pdf.GetShape()), -Ones(validEdgePoints2.pdf.GetShape()));

                    // auto indicesTri0 = Scalar(3) * its0.mTriangleId;
                    // auto indicesTri1 = Scalar(3) * its0.mTriangleId + Scalar(1);
                    // auto indicesTri2 = Scalar(3) * its0.mTriangleId + Scalar(2);
                    // Expr u, v, w, t;
                    // auto indicesPos0 = IndexedRead(scene.mIndexPosBuffer, indicesTri0, 0);
                    // auto indicesPos1 = IndexedRead(scene.mIndexPosBuffer, indicesTri1, 0);
                    // auto indicesPos2 = IndexedRead(scene.mIndexPosBuffer, indicesTri2, 0);
                    // auto position0 = IndexedRead(scene.mPositionBuffer, indicesPos0, 0);
                    // auto position1 = IndexedRead(scene.mPositionBuffer, indicesPos1, 0);
                    // auto position2 = IndexedRead(scene.mPositionBuffer, indicesPos2, 0);
                    // RayIntersectAD(VectorNormalize(bss.p0 - its.mPosition), its.mPosition,
                    //     position0, position1 - position0, position2 - position0, u, v, t);
                    // w = Scalar(1.0f) - u - v;
                    // auto u2 = w * Detach(position0) + u * Detach(position1) + v * Detach(position2);
                    // xDotN = VectorDot(n, u2);
                }
                // Step 3: trace towards the sensor
                // NOTE(review): the ray fields set below are only consumed by the commented-out
                // bounce loop; with that loop disabled nothing in this function reads them.
                ray.mThroughput = Ones(Shape({ ray.mNumRays }, VecType::Vec3));
                ray.mRayIdx = Tensori::ArrayRange(ray.mNumRays);
                ray.mPrevPdf = Ones(ray.mNumRays);
                ray.mSpecular = True(ray.mNumRays);
                ray.mPixelIdx = Tensori::ArrayRange(ray.mNumRays);
                // Gather G1 through both masks so it lines up with the doubly filtered samples.
                Tensorf value = IndexedRead(G1, validMask1.index, 0);
                Tensorf value2 = IndexedRead(value, validMask2.index, 0);
                // SHADOW_EPSILON is added after the product, so every doubly valid pair ends up
                // with a non-zero contribution.
                Tensorf pHat2 = Detach(emittedRadiance * baseVal * value2) + Scalar(SHADOW_EPSILON);

                // for (int iBounce = 0; iBounce < 1; iBounce++)
                // {
                //     if (ray.mNumRays == 0) break;
                //     Ray rayNext;
                //     Intersection itsNext;
                //     Expr pixelCoor;
                //     Tensori rayIdx;
                //     Tensorf pathContrib = EvalImportance(scene, *scene.mSensors[0], ray, its, rayNext, itsNext, pixelCoor, rayIdx);
                //     pixelCoor = Tensori(pixelCoor);

                //     if (rayIdx.LinearSize() > 0)
                //     {
                //         pHat2 = pHat2 * Where(pathContrib > Scalar(0.0f), Scalar(1.0f), Scalar(0.0f));
                //         // pHat2 = Detach(IndexedWrite(pathContrib, rayIdx, pHat2.GetShape(), 0)) * pathContrib;
                //         // pHat2 = Detach(pHat2 * IndexedWrite(pathContrib, rayIdx, pHat2.GetShape(), 0));
                //         Tensori validIndex3 = IndexedRead(validMask2.index, rayIdx, 0);
                //         // Tensorf val = Scalar(1.0f / float(spp)) * IndexedRead(value0, rayIdx, 0) * Detach(pathContrib);
                //         // boundaryTerm = boundaryTerm + camera.WriteToImage(val, pixelCoor);
                //     }

                //     ray = rayNext;
                // }

                // Scatter back in two hops: mask-2 space -> mask-1 space -> full input space.
                Tensorf pHat1 = Zeros(Shape({ validEdgePoints.numSamples }, VecType::Scalar1));
                pHat1 = pHat1 + IndexedWrite(Luminance(Abs(pHat2)), validMask2.index, pHat1.GetShape(), 0);
                pHat = pHat + IndexedWrite(pHat1, validMask1.index, pHat.GetShape(), 0);
                // validIndex = validMask1.index;
                // Compose the two index lists so validIndex refers to positions in the original input.
                validIndex = IndexedRead(validMask1.index, validMask2.index, 0);
            }
            return Abs(pHat);
        }

        int DirectBoundaryIntegrator2::SampleBoundarySegment(const Scene& scene, 
                const SecondaryEdgeInfo &secEdges, int numSamples, const Tensorf& rnd_b, 
                const Tensorf& pdf_b, EdgeDirectSample& samples) const {
            auto samples_x = X(rnd_b);
            auto samples_y = Y(rnd_b);
            auto samples_z = Z(rnd_b);

            EdgeSampleMaterialRep materialRep;
            BoundarySegSampleDirect secEdgeSamples;
            Tensorf edgeSamplingPdf;
            Tensori edgeIdx = SampleFromSecEdgesMaterial(secEdges, samples_x, materialRep, edgeSamplingPdf);
            ConvertEdgeSampleMaterialToSpatial(scene, materialRep, secEdgeSamples);
            secEdgeSamples.pdf = edgeSamplingPdf;

            // Sample a point on the light sources
            MaterialVertices lightMaterialRep;
            auto rndLight = MakeVector2(samples_y, samples_z);
            SampleEmitterDirect(scene, rndLight, lightMaterialRep);
            SpatialVertices lightSpatialRep;
            ConvertToSpatialVertices(scene, lightMaterialRep, lightSpatialRep);

            samples.edgePointsPdf = edgeSamplingPdf;
            samples.edgePoints = std::make_shared<BoundarySegSampleSecondary>(secEdgeSamples);
            samples.isBoundary = IndexedRead(secEdges.isBoundary, edgeIdx, 0);
            samples.lightSamples = std::make_shared<SpatialVertices>(lightSpatialRep);
            samples.lightPdf = lightMaterialRep.pdf;

            return 1;
        }

        int DirectBoundaryIntegrator2::SampleBoundarySegment(const Scene& scene, const SecondaryEdgeInfo &secEdges, 
                int numSamples, const Tensorf& rnd_b, const Tensorf& pdf_b, BoundarySegSampleDirect& secEdgeSamples) const {
            auto samples_x = X(rnd_b);
            auto samples_y = Y(rnd_b);
            auto samples_z = Z(rnd_b);

            EdgeSampleMaterialRep materialRep;
            Tensorf edgeSamplingPdf;
            Tensori edgeIdx = SampleFromSecEdgesMaterial(secEdges, samples_x, materialRep, edgeSamplingPdf);
            ConvertEdgeSampleMaterialToSpatial(scene, materialRep, secEdgeSamples);
            secEdgeSamples.pdf = edgeSamplingPdf;

            // Sample a point on the light sources
            auto rnd_light = MakeVector2(samples_y, samples_z);
            PositionSample lightSamples;
            scene.mLights[scene.mAreaLightIndex]->Sample(rnd_light, lightSamples);

            secEdgeSamples.p2 = lightSamples.p;
            secEdgeSamples.n = lightSamples.n;
            auto e = secEdgeSamples.p2 - Detach(secEdgeSamples.p0);
            auto distSqr = VectorSquaredLength(e);
            auto dist = Sqrt(distSqr);
            auto eNormalized = e / dist;
            auto cosTheta = VectorDot(secEdgeSamples.n, -eNormalized);
            auto isBoundary = IndexedRead(secEdges.isBoundary, edgeIdx, 0);
            auto cosine0 = VectorDot(eNormalized, secEdgeSamples.n0);
            auto cosine1 = VectorDot(eNormalized, secEdgeSamples.n1);
            auto valid0 = Abs(cosine0) > Scalar(EDGE_EPSILON);
            auto valid1 = (cosine0 > Scalar(EDGE_EPSILON) && cosine1 < Scalar(-EDGE_EPSILON)) || (cosine0 < Scalar(-EDGE_EPSILON) && cosine1 > Scalar(EDGE_EPSILON));
            auto rightSide = (isBoundary && valid0) || (~isBoundary && valid1);
            auto valid2 = cosTheta > Scalar(EPSILON);
            auto valid3 = dist > Scalar(SHADOW_EPSILON);
            secEdgeSamples.pdf = secEdgeSamples.pdf * lightSamples.pdf * distSqr / cosTheta * pdf_b;
            secEdgeSamples.maskValid = IndexMask(rightSide && valid2 && valid3);

            if (secEdgeSamples.maskValid.sum > 0)
                secEdgeSamples = secEdgeSamples.getValidCopy();
            return secEdgeSamples.maskValid.sum;
        }

        // Evaluates the boundary contribution of already-sampled edge/light pairs
        // and accumulates it into `boundaryTerm`.
        //
        // camera       : used by EvalImportance and to splat values via WriteToImage.
        // scene        : provides ray intersection, light evaluation and geometry buffers.
        // bss          : boundary-segment samples. bss.maskValid is overwritten; when at
        //                least one sample survives, bss is replaced by its valid subset.
        // boundaryTerm : accumulator receiving the splatted contributions.
        // spp          : every contribution is scaled by 1 / spp.
        // Returns 0 when no sample passes the visibility test, otherwise 1.
        int DirectBoundaryIntegrator2::EvalBoundarySegment(const Camera& camera, const Scene& scene, 
                BoundarySegSampleDirect& bss, Tensorf& boundaryTerm, int spp) {
            // Step 2: Compute the contrib from valid boundary segments (AD)
            // rayDir points from the light point p2 towards the edge point p0, so
            // `ray` continues past the edge and `ray0` (below) goes back to the light.
            Tensorf rayDir = Detach(VectorNormalize(bss.p0 - bss.p2));
            Ray ray(bss.p0, rayDir);
            Intersection its;
            Expr emittedRadiance, baseVal, xDotN;
            {
                Intersection its0;
                Ray ray0(bss.p0, -rayDir);
                scene.Intersect(ray0, its0);
                scene.Intersect(ray, its);
                // A sample is kept when ray0 lands within SHADOW_EPSILON of p2
                // and both rays hit a triangle.
                auto hitP = ray0.mOrg + its0.mTHit * ray0.mDir;
                Tensorb samePoint = VectorLength(hitP - bss.p2) < Scalar(SHADOW_EPSILON);
                bss.maskValid = IndexMask(samePoint && its0.mTriangleId != Scalar(-1)
                    && its.mTriangleId != Scalar(-1));
                if (bss.maskValid.sum == 0)
                {
                    return 0;
                }

                // Compact everything to the surviving samples.
                ray = ray.GetMaskedCopy(bss.maskValid);
                its0 = its0.GetMaskedCopy(bss.maskValid);
                its = its.GetMaskedCopy(bss.maskValid);
                bss = bss.getValidCopy();

                // Radiance emitted at the light-side hit; stays zero for hits
                // whose mEmitterId is -1.
                emittedRadiance = Zeros(Shape({ ray.mNumRays }, VecType::Vec3));
                IndexMask mask_light = its0.mEmitterId != Scalar(-1);
                if (mask_light.sum > 0) 
                {
                    Intersection its_light = its0.GetMaskedCopy(mask_light);
                    auto dir = Mask(ray.mDir, mask_light, 0);
                    scene.PostIntersect(its_light);
                    auto val = Detach(scene.mLights[scene.mAreaLightIndex]->Eval(its_light, dir));
                    emittedRadiance = emittedRadiance + IndexedWrite(val, mask_light.index, emittedRadiance->GetShape(), 0);
                }

                scene.PostIntersect(its);
                // NOTE(review): d_position is never read afterwards; kept in case
                // materializing its.mPosition here is relied upon - confirm and remove.
                Tensorf d_position = its.mPosition;
                // Detached geometric weight of the segment; its sign is set from
                // whether dot(e, edge2) and dot(e, n) agree.
                auto dist = VectorLength(bss.p2 - its.mPosition);
                auto cos2 = Abs(VectorDot(bss.n, ray.mDir));
                auto e = VectorCross(bss.edge, -ray.mDir);
                auto sinphi = VectorLength(e);
                auto proj = VectorNormalize(VectorCross(e, bss.n));
                auto sinphi2 = VectorLength(VectorCross(-ray.mDir, proj));
                auto itsT = VectorLength(its.mPosition - bss.p0);
                auto n = Detach(VectorNormalize(VectorCross(bss.n, proj)));
                auto sign0 = VectorDot(e, bss.edge2) > Scalar(0.0f);
                auto sign1 = VectorDot(e, n) > Scalar(0.0f);
                baseVal = Detach((itsT / dist) * (sinphi / sinphi2) * cos2);
                // Zero the weight where either sine is at or below EPSILON.
                baseVal = baseVal * (sinphi > Scalar(EPSILON)) * (sinphi2 > Scalar(EPSILON));
                baseVal = baseVal * Where(sign0 == sign1, Ones(bss.pdf.GetShape()), -Ones(bss.pdf.GetShape()));

                // Fetch the three vertices of the triangle hit on the light side.
                auto indicesTri0 = Scalar(3) * its0.mTriangleId;
                auto indicesTri1 = Scalar(3) * its0.mTriangleId + Scalar(1);
                auto indicesTri2 = Scalar(3) * its0.mTriangleId + Scalar(2);
                Expr u, v, w, t;
                auto indicesPos0 = IndexedRead(scene.mIndexPosBuffer, indicesTri0, 0);
                auto indicesPos1 = IndexedRead(scene.mIndexPosBuffer, indicesTri1, 0);
                auto indicesPos2 = IndexedRead(scene.mIndexPosBuffer, indicesTri2, 0);
                auto position0 = IndexedRead(scene.mPositionBuffer, indicesPos0, 0);
                auto position1 = IndexedRead(scene.mPositionBuffer, indicesPos1, 0);
                auto position2 = IndexedRead(scene.mPositionBuffer, indicesPos2, 0);
                // Re-intersect that triangle from the far-side hit through the edge
                // point; u, v, t are outputs of RayIntersectAD.
                RayIntersectAD(VectorNormalize(bss.p0 - its.mPosition), its.mPosition,
                    position0, position1 - position0, position2 - position0, u, v, t);
                w = Scalar(1.0f) - u - v;
                // The vertices are detached, so only u, v, w carry derivatives into u2.
                auto u2 = w * Detach(position0) + u * Detach(position1) + v * Detach(position2);
                xDotN = VectorDot(n, u2);
            }
            // Step 3: trace towards the sensor
            ray.mThroughput = Ones(Shape({ ray.mNumRays }, VecType::Vec3));
            ray.mRayIdx = Tensori::ArrayRange(ray.mNumRays);
            ray.mPrevPdf = Ones(ray.mNumRays);
            ray.mSpecular = True(ray.mNumRays);
            ray.mPixelIdx = Tensori::ArrayRange(ray.mNumRays);
            // Only xDotN stays attached; the radiometric factor is detached.
            Tensorf value0 = Detach(emittedRadiance * baseVal / bss.pdf) * xDotN;

            // The loop bound is 1, so exactly one connection to the sensor is evaluated.
            for (int iBounce = 0; iBounce < 1; iBounce++)
            {
                if (ray.mNumRays == 0) break;
                Ray rayNext;
                Intersection itsNext;
                Expr pixelCoor;
                Tensori rayIdx;
                Tensorf pathContrib = EvalImportance(scene, camera, ray, its, rayNext, itsNext, pixelCoor, rayIdx);

                if (rayIdx.LinearSize() > 0)
                {
                    // NaN contributions are zeroed before splatting, matching the
                    // EdgeDirectSample overload of this function.
                    Tensorf val = NaNToZero(Scalar(1.0f / float(spp)) * IndexedRead(value0, rayIdx, 0) * Detach(pathContrib));
                    boundaryTerm = boundaryTerm + camera.WriteToImage(val, pixelCoor);
                }

                ray = rayNext;
                its = itsNext;
            }

            return 1;
        }

        int DirectBoundaryIntegrator2::EvalBoundarySegment(const Scene& scene, 
                const EdgeDirectSample& edgeSampleDirect, Tensorf& boundaryTerm, float spp) const {
            const BoundarySegSampleSecondary& edgePoints = *edgeSampleDirect.edgePoints;
            const SpatialVertices& lightSamples = *edgeSampleDirect.lightSamples;
            auto e = lightSamples.position - Detach(edgePoints.p0);
            auto distSqr = VectorSquaredLength(e);
            auto dist = Sqrt(distSqr);
            auto eNormalized = e / dist;
            auto cosTheta = VectorDot(lightSamples.normal, -eNormalized);
            auto isBoundary = edgeSampleDirect.isBoundary;
            auto cosine0 = VectorDot(eNormalized, edgePoints.n0);
            auto cosine1 = VectorDot(eNormalized, edgePoints.n1);
            auto valid0 = Abs(cosine0) > Scalar(EDGE_EPSILON);
            auto valid1 = (cosine0 > Scalar(EDGE_EPSILON) && cosine1 < Scalar(-EDGE_EPSILON)) || (cosine0 < Scalar(-EDGE_EPSILON) && cosine1 > Scalar(EDGE_EPSILON));
            auto rightSide = (isBoundary && valid0) || (~isBoundary && valid1);
            auto valid2 = cosTheta > Scalar(EPSILON);
            auto valid3 = dist > Scalar(SHADOW_EPSILON);
            Tensorf G1 = cosTheta / distSqr;
            Tensorf pdf = edgeSampleDirect.edgePointsPdf * edgeSampleDirect.lightPdf;
            IndexMask validMask1(rightSide && valid2 && valid3);

            const BoundarySegSampleSecondary& validEdgePoints1 = edgePoints.GetMaskedCopy(validMask1);
            const SpatialVertices& lightSamples1 = lightSamples.GetMaskedCopy(validMask1);
            Tensorf rayDir = Detach(VectorNormalize(validEdgePoints1.p0 - lightSamples1.position));
            Ray ray(validEdgePoints1.p0, rayDir);
            Intersection its;
            Expr emittedRadiance, baseVal, xDotN;
            IndexMask validMask2;
            {
                Intersection its0;
                Ray ray0(validEdgePoints1.p0, -rayDir);
                scene.Intersect(ray0, its0);
                scene.Intersect(ray, its);
                auto hitP = ray0.mOrg + its0.mTHit * ray0.mDir;
                Tensorb samePoint = VectorLength(hitP - lightSamples1.position) < Scalar(SHADOW_EPSILON);
                validMask2 = IndexMask(samePoint && its0.mTriangleId != Scalar(-1)
                    && its.mTriangleId != Scalar(-1));
                if (validMask2.sum == 0)
                {
                    return 0;
                }

                ray = ray.GetMaskedCopy(validMask2);
                its0 = its0.GetMaskedCopy(validMask2);
                its = its.GetMaskedCopy(validMask2);
                Tensori validIndex;
                const BoundarySegSampleSecondary& validEdgePoints2 = validEdgePoints1.GetMaskedCopy(validMask2);
                const SpatialVertices& lightSamples2 = lightSamples1.GetMaskedCopy(validMask2);

                emittedRadiance = Zeros(Shape({ ray.mNumRays }, VecType::Vec3));
                IndexMask mask_light = its0.mEmitterId != Scalar(-1);
                if (mask_light.sum > 0) 
                {
                    Intersection its_light = its0.GetMaskedCopy(mask_light);
                    auto dir = Mask(ray.mDir, mask_light, 0);
                    scene.PostIntersect(its_light);
                    auto val = Detach(scene.mLights[scene.mAreaLightIndex]->Eval(its_light, dir));
                    emittedRadiance = emittedRadiance + IndexedWrite(val, mask_light.index, emittedRadiance->GetShape(), 0);
                }

                scene.PostIntersect(its);
                Tensorf d_position = its.mPosition;
                auto dist = VectorLength(lightSamples2.position - its.mPosition);
                auto cos2 = Abs(VectorDot(lightSamples2.normal, ray.mDir));
                auto e = VectorCross(validEdgePoints2.edge, -ray.mDir);
                auto sinphi = VectorLength(e);
                auto proj = VectorNormalize(VectorCross(e, lightSamples2.normal));
                auto sinphi2 = VectorLength(VectorCross(-ray.mDir, proj));
                auto itsT = VectorLength(its.mPosition - validEdgePoints2.p0);
                auto n = Detach(VectorNormalize(VectorCross(lightSamples2.normal, proj)));
                auto sign0 = VectorDot(e, validEdgePoints2.edge2) > Scalar(0.0f);
                auto sign1 = VectorDot(e, n) > Scalar(0.0f);
                baseVal = Detach((itsT / dist) * (sinphi / sinphi2) * cos2);
                baseVal = baseVal * (sinphi > Scalar(EPSILON)) * (sinphi2 > Scalar(EPSILON));
                baseVal = baseVal * Where(sign0 == sign1, Ones(validEdgePoints2.pdf.GetShape()), -Ones(validEdgePoints2.pdf.GetShape()));

                auto indicesTri0 = Scalar(3) * its0.mTriangleId;
                auto indicesTri1 = Scalar(3) * its0.mTriangleId + Scalar(1);
                auto indicesTri2 = Scalar(3) * its0.mTriangleId + Scalar(2);
                Expr u, v, w, t;
                auto indicesPos0 = IndexedRead(scene.mIndexPosBuffer, indicesTri0, 0);
                auto indicesPos1 = IndexedRead(scene.mIndexPosBuffer, indicesTri1, 0);
                auto indicesPos2 = IndexedRead(scene.mIndexPosBuffer, indicesTri2, 0);
                auto position0 = IndexedRead(scene.mPositionBuffer, indicesPos0, 0);
                auto position1 = IndexedRead(scene.mPositionBuffer, indicesPos1, 0);
                auto position2 = IndexedRead(scene.mPositionBuffer, indicesPos2, 0);
                RayIntersectAD(VectorNormalize(validEdgePoints2.p0 - its.mPosition), its.mPosition,
                    position0, position1 - position0, position2 - position0, u, v, t);
                w = Scalar(1.0f) - u - v;
                auto u2 = w * Detach(position0) + u * Detach(position1) + v * Detach(position2);
                xDotN = VectorDot(n, u2);
            }
            // Step 3: trace towards the sensor
            ray.mThroughput = Ones(Shape({ ray.mNumRays }, VecType::Vec3));
            ray.mRayIdx = Tensori::ArrayRange(ray.mNumRays);
            ray.mPrevPdf = Ones(ray.mNumRays);
            ray.mSpecular = True(ray.mNumRays);
            ray.mPixelIdx = Tensori::ArrayRange(ray.mNumRays);
            Tensorf value = IndexedRead(G1 * Where(pdf > Scalar(0.0f), Scalar(1.0f) / pdf, Scalar(0.0f)), validMask1.index, 0);
            Tensorf value2 = IndexedRead(value, validMask2.index, 0);
            Tensorf value0 = Detach(emittedRadiance * baseVal * value2) * xDotN;

            const Camera& camera = *(scene.mSensors[0]);
            for (int iBounce = 0; iBounce < 1; iBounce++)
            {
                if (ray.mNumRays == 0) break;
                Ray rayNext;
                Intersection itsNext;
                Expr pixelCoor;
                Tensori rayIdx;
                Tensorf pathContrib = EvalImportance(scene, camera, ray, its, rayNext, itsNext, pixelCoor, rayIdx);

                if (rayIdx.LinearSize() > 0)
                {
                    Tensorf val = NaNToZero(Scalar(1.0f / float(spp)) * IndexedRead(value0, rayIdx, 0) * Detach(pathContrib));
                    boundaryTerm = boundaryTerm + camera.WriteToImage(val, pixelCoor);
                }

                ray = rayNext;
                its = itsNext;
            }

            return 1;
        }

        void DirectBoundaryIntegrator2::SampleEmitterDirect(const Scene& scene, const Tensorf& rndLight, MaterialVertices& vLight) const {
            // Sample point on area lights
            PositionSample lightSamples;
            scene.mLights[scene.mAreaLightIndex]->Sample(rndLight, lightSamples);

            vLight.numOfSamples = rndLight.NumElements();
            vLight.prevId = Zeros(Shape({rndLight.NumElements()}, VecType::Scalar1));
            vLight.triangleId = IndexedRead(scene.mEmitTriIdToTriIdBuffer, lightSamples.triangleId, 0);
            vLight.emitterId = lightSamples.lightId;
            vLight.bsdfId = Ones(vLight.triangleId.GetShape()) * Scalar(-1);
            vLight.u = lightSamples.baryU;
            vLight.v = lightSamples.baryV;
            vLight.pdf = lightSamples.pdf;
        }

        // Promotes the reservoirs collected in `reservoirsOnFile` to `reservoirs`
        // for temporal reuse, resets `reuseMask` to zeros (one entry per promoted
        // reservoir) and replaces `reservoirsOnFile` with a fresh, empty instance.
        //
        // Does nothing when temporal reuse is disabled, or when `reservoirsOnFile`
        // is null or holds no samples.
        void DirectBoundaryIntegrator2::Step() {
            if (!haveTemporalReuse) return;
            // The null case has to exit here too: the pointer is dereferenced right
            // below, so letting nullptr through would be undefined behavior.
            if (reservoirsOnFile == nullptr || reservoirsOnFile->numOfSamples == 0) return;
            EdgeReservoir r;
            r = *reservoirsOnFile;
            // Disabled variant that combined reservoirsOnFile with the non-reused
            // entries of the current reservoirs; kept for reference.
            // // if (reservoirs->numOfSamples > 0) {
            // //     IndexMask nonreuseReservoirMask = IndexMask((reuseMask > Scalar(0)) && (T < Scalar(1)));
            // //     if (nonreuseReservoirMask.sum > 0) {
            // //         Tensori tempT = Zeros(Shape({reservoirsOnFile->numOfSamples}, VecType::Scalar1));
            // //         T = Concat(tempT, IndexedRead(T, nonreuseReservoirMask.index, 0), 0);
            // //         EdgeReservoir nonreuseReservoir = reservoirs->GetIndexedCopy(nonreuseReservoirMask.index, nonreuseReservoirMask.sum);
            // //         EdgeReservoir::Combine(*reservoirsOnFile, nonreuseReservoir, r);
            // //     } else {
            // //         r = *reservoirsOnFile;
            // //         T = Zeros(Shape({reservoirsOnFile->numOfSamples}, VecType::Scalar1));
            // //     }
            // // } else {
            // //     r = *reservoirsOnFile;
            // //     T = Zeros(Shape({reservoirsOnFile->numOfSamples}, VecType::Scalar1));
            // // }
            // // T = T + Scalar(1);
            // Release the previous reservoirs before allocating the replacement.
            reservoirs.reset();
            reservoirs = std::make_shared<EdgeReservoir>(r);
            reuseMask = Zeros(Shape({reservoirs->numOfSamples}, VecType::Scalar1));
            reservoirsOnFile.reset();
            reservoirsOnFile = std::make_shared<EdgeReservoir>();
        }

        std::shared_ptr<EdgeReservoir> DirectBoundaryIntegrator2::PrefilerReservoir(const Scene& scene, 
            const Camera& camera,
            const std::shared_ptr<EdgeReservoir>& edgeReservoirs, bool updateReuseMask) const {
            BoundarySegSampleSecondary edgePoints;
            ConvertEdgeSampleMaterialToSpatial(scene, 
                        *edgeReservoirs->edgePoints->materialRep,
                        edgePoints);
            SpatialVertices lightPoints;
            ConvertToSpatialVertices(scene, *edgeReservoirs->x->vertices, lightPoints);
            Tensorf edgeSecRayDir = Detach(VectorNormalize(edgePoints.p0 - lightPoints.position));
            Ray edgeSecRay(edgePoints.p0, edgeSecRayDir);
            Ray edgeSecRay1(edgePoints.p0, -edgeSecRayDir);
            Intersection edgeSecIts1;
            scene.Intersect(edgeSecRay1, edgeSecIts1);
            Tensorf edgeSecRay1HitP = edgeSecRay1.mOrg + edgeSecIts1.mTHit * edgeSecRay1.mDir;
            Tensorb sameLightPoints = VectorLength(edgeSecRay1HitP - lightPoints.position) < Scalar(SHADOW_EPSILON);
            Intersection edgeSecIts;
            scene.Intersect(edgeSecRay, edgeSecIts);

            IndexMask edgeSecValidMask = (edgeSecIts.mTriangleId != Scalar(-1) 
                    && edgeSecIts1.mTriangleId != Scalar(-1)
                    && sameLightPoints);
            edgeSecIts = edgeSecIts.GetMaskedCopy(edgeSecValidMask);
            scene.PostIntersect(edgeSecIts);

            // Check if the shading point is in image plane
            SensorDirectSample sds = camera.sampleDirect(edgeSecIts.mPosition);
            IndexMask isInImagePlaneMask(sds.isValid);
            // Check visibility
            Ray primaryRay;
            camera.GenerateBoundaryRays(sds, primaryRay);
            Tensorf pToCamDist = VectorLength(IndexedRead(edgeSecIts.mPosition, isInImagePlaneMask.index, 0) - 
                    camera.mPosTensor);
            primaryRay.mMin = Scalar(SHADOW_EPSILON);
            primaryRay.mMax = pToCamDist - Scalar(SHADOW_EPSILON);
            Tensorb isVisible;
            scene.Occluded(primaryRay, isVisible);
            isVisible = IndexedWrite(isVisible, isInImagePlaneMask.index,
                    isInImagePlaneMask.mask.GetShape(), 0);

            IndexMask validReservoirIndex(isInImagePlaneMask.mask * isVisible);
            if (updateReuseMask) {
                reuseMask = reuseMask * IndexedWrite((Scalar(1) - validReservoirIndex.mask),
                    edgeSecValidMask.index,
                    reuseMask.GetShape(), 0);
                reuseMask = reuseMask * edgeSecValidMask.mask;
            }
            Tensori originalIndex = IndexedRead(edgeSecValidMask.index, validReservoirIndex.index, 0);
            return std::make_shared<EdgeReservoir>(edgeReservoirs->GetIndexedCopy(originalIndex, validReservoirIndex.sum));
        }
        
    }
}
