/*
 * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "BoundaryPixel2.h"
#include "cukd/cpfinder.h"
#include "cukd/knnfinder.h"
#include "LibTensorRay/Renderer/Integrator.h"

namespace EDX
{
namespace TensorRay
{

void PixelBoundaryIntegrator2::Integrate(const Scene& scene, Tensorf& image) const
{
#if USE_PROFILING
    nvtxRangePushA(__FUNCTION__);
#endif
    if (mSpp == 0) return;
    const Camera& camera = *scene.mSensors[0];

    Timer timer;
    timer.Start();
    Shape imageShape = Shape({ camera.mResX * camera.mResY }, VecType::Vec3);
    int npass = std::ceil(mSpp / mSppBatch);
    for (int ipass = 0; ipass < npass; ipass++) 
    {
        int numRays = camera.mResX * camera.mResY * mSppBatch;
        int numRaysPerAnti = numRays / mAntitheticSpp;

        BoundarySegSamplePixel bss;
        // Step 1: Sample point on the pixel boundary
        SampleBoundarySegmentPixel(camera, mSppBatch, mAntitheticSpp, bss);

        // Step 2: Check if the ray hit anything
        Intersection its;
        TensorRay::Ray rayFromEdge(bss.p0, VectorNormalize(bss.p0 - camera.mPosTensor));                
        scene.Intersect(rayFromEdge, its);
        bss.maskValid = IndexMask(its.mBsdfId != Scalar(-1));
        if (bss.maskValid.sum == 0)
        {
#if USE_PROFILING
            nvtxRangePop();
#endif
            return;
        }
        Expr valid_boundary_ray_index = bss.maskValid.index;
        Expr valid_boundary_pixel_index = Mask(bss.pixelIdx, bss.maskValid, 0);

        Tensorf hitT = Mask(its.mTHit, bss.maskValid, 0);
        auto hitP = Mask(rayFromEdge.mOrg, bss.maskValid, 0) + hitT * Mask(rayFromEdge.mDir, bss.maskValid, 0);
        its = its.GetMaskedCopy(bss.maskValid);
        bss = bss.getValidCopy();

        // hitP should always be visible to camera
        SensorDirectSample sds = camera.sampleDirect(hitP);
        // Overwrite valid mask and pixel index
        sds.isValid = bss.maskValid.mask;
        sds.pixelIdx = bss.pixelIdx;

        TensorRay::Ray rays;        // Ray from camera
        camera.GenerateBoundaryRays(sds, rays);
        rays.mRayIdx = bss.rayIdx;  // Overwrite rayIdx for antithetic sampling

        // Step 4: Compute the boundary contribution
        Expr baseVal;
        scene.PostIntersect(its);
        auto dist = VectorLength(its.mPosition - camera.mPosTensor);
        auto dist1 = VectorLength(bss.p0 - camera.mPosTensor);
        auto cos2 = Abs(VectorDot(its.mGeoNormal, -rays.mDir));
        auto e = VectorCross(bss.edge, rays.mDir);
        auto sinphi = VectorLength(e);
        auto proj = VectorNormalize(VectorCross(e, its.mGeoNormal));
        auto sinphi2 = VectorLength(VectorCross(rays.mDir, proj));
        auto n = Detach(VectorNormalize(VectorCross(its.mGeoNormal, proj)));
        auto sign0 = VectorDot(e, bss.edge2) > Scalar(0.0f);
        auto sign1 = VectorDot(e, n) > Scalar(0.0f);
        baseVal = (dist / dist1) * (sinphi / sinphi2) * cos2;
        baseVal = baseVal * (sinphi > Scalar(EPSILON)) * (sinphi2 > Scalar(EPSILON));
        baseVal = baseVal * Where(sign0 == sign1, -Ones(bss.pdf.GetShape()), Ones(bss.pdf.GetShape()));  // revert signs

        auto indicesTri0 = Scalar(3) * its.mTriangleId;
        auto indicesTri1 = Scalar(3) * its.mTriangleId + Scalar(1);
        auto indicesTri2 = Scalar(3) * its.mTriangleId + Scalar(2);
        Expr u, v, w, t;
        auto indicesPos0 = IndexedRead(scene.mIndexPosBuffer, indicesTri0, 0);
        auto indicesPos1 = IndexedRead(scene.mIndexPosBuffer, indicesTri1, 0);
        auto indicesPos2 = IndexedRead(scene.mIndexPosBuffer, indicesTri2, 0);
        auto position0 = IndexedRead(scene.mPositionBuffer, indicesPos0, 0);
        auto position1 = IndexedRead(scene.mPositionBuffer, indicesPos1, 0);
        auto position2 = IndexedRead(scene.mPositionBuffer, indicesPos2, 0);
        RayIntersectAD(VectorNormalize(bss.p0 - camera.mPosTensor), camera.mPosTensor,
            position0, position1 - position0, position2 - position0, u, v, t);
        w = Scalar(1.0f) - u - v;
        auto u2 = w * Detach(position0) + u * Detach(position1) + v * Detach(position2);
        auto xDotN = VectorDot(n, u2);

        Tensorf radiance = Zeros(Shape({ numRays }, VecType::Vec3));
        auto Le = Detach(EvalRadianceEmitted(scene, rays, its));
        radiance = radiance + IndexedWrite(Le, rays.mRayIdx, radiance.GetShape(), 0);
        // for (int iBounce = 0; iBounce < mMaxBounces; iBounce++)
        // {
        //     if (rays.mNumRays == 0) break;
        //     Ray raysNext;
        //     Intersection itsNext;
        //     Tensorf antithetic_rnd_light = Tensorf::RandomFloat(Shape({ numRaysPerAnti }, VecType::Vec2));
        //     Tensorf antithetic_rnd_bsdf = Tensorf::RandomFloat(Shape({ numRaysPerAnti }, VecType::Vec3));
        //     Expr rnd_light = IndexedRead(antithetic_rnd_light, rays.mRayIdx % Scalar(numRaysPerAnti), 0);
        //     Expr rnd_bsdf = IndexedRead(antithetic_rnd_bsdf, rays.mRayIdx % Scalar(numRaysPerAnti), 0);
        //     Expr val = Detach(EvalRadianceDirect(scene, rays, its, rnd_light, rnd_bsdf, raysNext, itsNext));
        //     radiance = radiance + IndexedWrite(val, rays.mRayIdx, radiance.GetShape(), 0);
        //     rays = raysNext;
        //     its = itsNext;
        // }
        if (rays.mNumRays > 0) {
            SpatialVertices s1Spatial;
            SpatialVertices l1Spatial;
            SampleEmitter(scene, rays, its, s1Spatial, l1Spatial);

            // Spatial vertices for camera
            SpatialVertices vCamera;
            Tensorf contrib;
            vCamera.position = Broadcast(camera.mPosTensor, Shape({ s1Spatial.numOfSamples }, VecType::Vec3));
            EvalScreenSpaceReSTIR(scene, vCamera, s1Spatial, l1Spatial, contrib);

            radiance = radiance + Detach(IndexedWrite(rays.mThroughput * contrib, rays.mRayIdx, radiance.GetShape(), 0));
        }

        radiance = IndexedRead(radiance, valid_boundary_ray_index, 0);
        Tensorf boundaryTerm = Detach(radiance * baseVal / bss.pdf) * xDotN;
        boundaryTerm = Tensorf({ 1.0f / float(mSpp) }) * 
            IndexedWrite(boundaryTerm, valid_boundary_pixel_index, image.GetShape(), 0);

        boundaryTerm = boundaryTerm - Detach(boundaryTerm);
        mGradHandler.AccumulateDeriv(boundaryTerm);
        if (!mDLoss.Empty()) 
        {
            boundaryTerm.Backward(mDLoss);
            AccumulateGradsAndReleaseGraph();
        }

        if (mVerbose)
            std::cout << string_format("[PixelBoundary2] #Pass %d / %d, %d kernels launched\r", ipass + 1, npass, KernelLaunchCounter::GetHandle());
        KernelLaunchCounter::Reset();
    }
    if (mVerbose)
        std::cout << string_format("[PixelBoundary2] Total Elapsed time = %f (%f samples/pass, %d pass)", timer.GetElapsedTime(), mSppBatch, npass) << std::endl;
#if USE_PROFILING
        nvtxRangePop();
#endif
}

void PixelBoundaryIntegrator2::SampleEmitter(const Scene& scene, const Ray& rays, Intersection& its, SpatialVertices& s1Spatial,
            SpatialVertices& spatialLight) const {
    MaterialVertices s1Material;
    if (rays.mNumRays > 0)
    {
        scene.PostIntersect(its);
        s1Material.numOfSamples = rays.mNumRays;
        s1Material.prevId = rays.mPixelIdx;
        s1Material.triangleId = its.mTriangleId;
        s1Material.emitterId = its.mEmitterId;
        s1Material.bsdfId = its.mBsdfId;
        s1Material.u = its.mBaryU;
        s1Material.v = its.mBaryV;
        s1Material.pdf = Ones(s1Material.u.GetShape());
    }

    const Camera& camera = *scene.mSensors[0];
    int nRaysInit = mSppBatch * camera.mResX * camera.mResY;
    int numRaysPerAnti = nRaysInit / mAntitheticSpp;

    Reservoir curIterationReservoir(rays.mNumRays);
    ConvertToSpatialVertices(scene, s1Material, *(curIterationReservoir.shadingPoints));
    s1Spatial = *(curIterationReservoir.shadingPoints);
    for (int i = 0; i < M; ++i) {
        if (rays.mNumRays == 0) break;
        LightSample curLightSample;
        Tensorf antitheticRndLight = Tensorf::RandomFloat(Shape({ numRaysPerAnti }, VecType::Vec2));
        Tensorf rndLight = IndexedRead(antitheticRndLight, rays.mRayIdx % Scalar(numRaysPerAnti), 0);

        SampleEmitterDirect(scene, rays, its, rndLight, *curLightSample.vertices);

        SpatialVertices curSpatialLightSamples;
        ConvertToSpatialVertices(scene, *curLightSample.vertices, curSpatialLightSamples);

        curLightSample.pHat = EvalPHat(scene, rays, 
                *(curIterationReservoir.shadingPoints), curSpatialLightSamples, true);
        Tensorf rndReservoirSample = Tensorf::RandomFloat(
            Shape({ numRaysPerAnti }, VecType::Scalar1));
        Expr rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(numRaysPerAnti), 0);
        UpdateReservoir(curIterationReservoir, curLightSample, 
            curLightSample.pHat / curLightSample.vertices->pdf, Scalar(1), rndReservoir);
    }
    curIterationReservoir.W = Detach(Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
                                    curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
                                    Scalar(0.0f)));

    std::shared_ptr<Reservoir> reservoir(mPathSampler->reservoir);
    if (haveTemporalReuse && reservoir != nullptr, reservoir->numOfSamples > 0) {
        Tensori pickedIndex = Zeros(Shape({curIterationReservoir.numOfSamples}, VecType::Scalar1));
        Expr m0 = curIterationReservoir.m;
        Expr w0 = curIterationReservoir.wAverage;
        // reservoir->m = Where(reservoir->m > Scalar(M * historyLength), Scalar(M * historyLength), reservoir->m);

        SpatialVertices prevSpatialPoints;
        ConvertToSpatialVertices(scene, *reservoir->shadingPoints, prevSpatialPoints);

        std::shared_ptr<Reservoir> prefiltReservoir = PrefilerReservoir(scene,
            reservoir, camera, prevSpatialPoints.position);
        prefiltReservoir->m = Where(prefiltReservoir->m > Scalar(M * 4), Scalar(M * 4), prefiltReservoir->m);
        Tensori localShiftMapping = Zeros(Shape({ numRaysPerAnti * k }, VecType::Scalar1));
        IndexMask validReservoirIndex = (prefiltReservoir->m > Scalar(0));
        {
            Tensorf curPositions = Zeros(Shape({numRaysPerAnti}, VecType::Vec3));
            curPositions = IndexedRead(curIterationReservoir.shadingPoints->position, 
                        Tensori::ArrayRange(0, numRaysPerAnti, 1, false), 0);
            for (int j = 1; j < mAntitheticSpp; ++j) {
                curPositions = curPositions + IndexedRead(curIterationReservoir.shadingPoints->position, 
                            Tensori::ArrayRange(j * numRaysPerAnti, (j + 1) * numRaysPerAnti, 1, false), 0);
            }
            curPositions = curPositions / Scalar(mAntitheticSpp);

            ConvertToSpatialVertices(scene, 
                prefiltReservoir->shadingPoints->GetMaskedCopy(validReservoirIndex), 
                prevSpatialPoints);
            cukd::FindKNN(prevSpatialPoints.position.Data(), 
                        prevSpatialPoints.numOfSamples, 
                        curPositions.Data(), 
                        numRaysPerAnti, 
                        localShiftMapping.Data(), 
                        k);
            // Tensori localShiftMapping1;
            Tensori localShiftMapping1;
            for (int j = 0; j < k; ++j) {
                Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                    Tensori::ArrayRange(numRaysPerAnti * j, numRaysPerAnti * (j + 1), 1, false), 0);
                Tensori curReservoirMapping = IndexedRead(curLocalShiftMapping, rays.mRayIdx % Scalar(numRaysPerAnti), 0);
                if (j == 0) {
                    localShiftMapping1 = curReservoirMapping;
                } else {
                    localShiftMapping1 = Concat(localShiftMapping1, curReservoirMapping, 0);
                }
            }
            localShiftMapping = localShiftMapping1;
            localShiftMapping.CopyToHost();
        }

        shuffle(reservoirToUse.begin(), reservoirToUse.end(), randomGenerator);
        for (int j = 0; j < k1; ++j) {
            int i = reservoirToUse[j];
            Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                Tensori::ArrayRange(rays.mNumRays * i, rays.mNumRays * (i + 1), 1, false), 0);
            Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, 
                curLocalShiftMapping, 0);
            Reservoir previousReservoirCopy = prefiltReservoir->GetIndexedCopy(
                globalShiftMapping, 
                rays.mNumRays);

            SpatialVertices prevSpatialLightSamples;
            ConvertToSpatialVertices(scene, *previousReservoirCopy.x->vertices, prevSpatialLightSamples);

            SpatialVertices previousShadingPoint;
            ConvertToSpatialVertices(scene, *(previousReservoirCopy.shadingPoints), previousShadingPoint);

            previousReservoirCopy.x->pHat = EvalPHat(scene, rays, *(curIterationReservoir.shadingPoints), prevSpatialLightSamples, true);

            // Apply normal test
            Expr normalDotProduct = VectorDot(previousReservoirCopy.shadingPoints->normal, curIterationReservoir.shadingPoints->normal);
            Expr dist = Mean(Square(previousReservoirCopy.shadingPoints->position - curIterationReservoir.shadingPoints->position));
            Expr nFilter = ((normalDotProduct > Scalar(reservoirMergeNormalThreshold)) * 
                            (dist < Scalar(reservoirMergeDistThreshold)));
            previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
            previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));
            previousReservoirCopy.wAverage = Where(nFilter, previousReservoirCopy.wAverage, Scalar(0.0f));

            Tensorf rndReservoirSample = Tensorf::RandomFloat(
                Shape({ numRaysPerAnti }, VecType::Scalar1));

            Tensorf rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(numRaysPerAnti), 0);
            Tensori curMask = UpdateReservoir(curIterationReservoir, previousReservoirCopy, scene, rays, rndReservoir);
            pickedIndex = Where(curMask > Scalar(0), Scalar(j + 1), Scalar(0));
        }

        // Tensori Z = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        // Tensorf Z = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        // Tensorf _M = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        Tensorf pI = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        Tensorf pSum = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        SpatialVertices curLightSampleSpatialPoints;
        ConvertToSpatialVertices(scene, *(curIterationReservoir.x->vertices), curLightSampleSpatialPoints);
        for (int j = 0; j < k1; ++j) {
            int i = reservoirToUse[j];
            Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                Tensori::ArrayRange(rays.mNumRays * i, rays.mNumRays * (i + 1), 1, false), 0);
            Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, curLocalShiftMapping, 0);
            Reservoir previousReservoirCopy = prefiltReservoir->GetIndexedCopy(globalShiftMapping, rays.mNumRays);

            SpatialVertices previousShadingPoint;
            ConvertToSpatialVertices(scene, *(previousReservoirCopy.shadingPoints), previousShadingPoint);

            // Apply normal test
            Expr normalDotProduct = VectorDot(previousReservoirCopy.shadingPoints->normal, curIterationReservoir.shadingPoints->normal);
            Expr dist = Mean(Square(previousReservoirCopy.shadingPoints->position - curIterationReservoir.shadingPoints->position));
            Expr nFilter = ((normalDotProduct > Scalar(reservoirMergeNormalThreshold)) * 
                            (dist < Scalar(reservoirMergeDistThreshold)));
            previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
            previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));

            Expr pHatPrevCur = EvalPHat(scene, rays, previousShadingPoint, curLightSampleSpatialPoints, true);
            pSum = pSum + previousReservoirCopy.m * pHatPrevCur * nFilter;
            pI = Where(pickedIndex == Scalar(j + 1), pHatPrevCur, pI);
            // Z = Z + Where(pHatPrevCur > Scalar(0.0f), previousReservoirCopy.m, Scalar(0));
            // Z = Z + Where(pHatPrevCur > Scalar(0.0f), previousReservoirCopy.m * previousReservoirCopy.wAverage + Scalar(RESTIR_PHAT_EPSILON), Scalar(0.0f));
            // _M = _M + previousReservoirCopy.m * previousReservoirCopy.wAverage + Scalar(RESTIR_PHAT_EPSILON);
        }

        // Test itself
        Expr pHatCurCur = EvalPHat(scene, rays, *(curIterationReservoir.shadingPoints), curLightSampleSpatialPoints, true);
        pSum = pSum + m0 * pHatCurCur;
        pI = Where(pickedIndex == Scalar(0), pHatCurCur, pI);
        // Z = Z + Where(pHatCurCur > Scalar(0.0f), m0 * w0 + Scalar(RESTIR_PHAT_EPSILON), Scalar(0.0f));
        // _M = _M + m0 * w0 + Scalar(RESTIR_PHAT_EPSILON);

        Tensorf mi = Where(pSum > Scalar(0.0f), pI / pSum, Scalar(0.0f));
        Tensorf wSum = curIterationReservoir.wAverage * curIterationReservoir.m;
        curIterationReservoir.W = Where(pHatCurCur > Scalar(0.0f), Scalar(1.0f) / pHatCurCur * wSum * mi, Scalar(0.0f));
        // curIterationReservoir.W = Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
        //                                 curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
        //                                 Scalar(0.0f));
        // curIterationReservoir.W = Where(Z > Scalar(0), curIterationReservoir.W * Tensorf(_M / Z), Scalar(0.0f));
    }

    // Reservoir r;
    // Reservoir::Combine(*reservoirOnFile, curIterationReservoir, r);
    // mPathSampler->reservoirOnFile = std::make_shared<Reservoir>(r);

    ConvertToSpatialVertices(scene, *(curIterationReservoir.x->vertices), spatialLight);
    spatialLight.pdf = Where(curIterationReservoir.W > Scalar(0.0f), Scalar(1.0f) / curIterationReservoir.W,
                                    Scalar(0.0f));
}

std::shared_ptr<Reservoir> PixelBoundaryIntegrator2::PrefilerReservoir(const Scene& scene, 
        const std::shared_ptr<Reservoir>& reservoir,
        const Camera& camera,
        const Tensorf& curPosition) const {
    // Check if the shading point is in image plane
    SensorDirectSample sds = camera.sampleDirect(curPosition);
    IndexMask isInImagePlaneMask = IndexMask(sds.isValid);
    // Check visibility
    Ray primaryRay;
    camera.GenerateBoundaryRays(sds, primaryRay);
    auto pToCamDist = VectorLength(IndexedRead(curPosition, isInImagePlaneMask.index, 0) - 
            camera.mPosTensor);
    primaryRay.mMin = Scalar(SHADOW_EPSILON);
    primaryRay.mMax = pToCamDist - Scalar(SHADOW_EPSILON);
    Tensorb isVisible;
    scene.Occluded(primaryRay, isVisible);
    isVisible = IndexedWrite(isVisible, isInImagePlaneMask.index,
            isInImagePlaneMask.mask.GetShape(), 0);

    IndexMask validReservoirIndex = IndexMask(isInImagePlaneMask.mask * isVisible);
    return std::make_shared<Reservoir>(reservoir->GetIndexedCopy(validReservoirIndex.index, validReservoirIndex.sum));
}

}
}
