// Copyright @yucwang 2022

#include "MaterialSpaceReSTIR.h"

#include "../Algorithm1.h"
#include "cukd/cpfinder.h"
#include "cukd/knnfinder.h"

namespace EDX {
namespace TensorRay {

// Copies the RIS / material-space ReSTIR parameters out of `options` and
// rebuilds the pool of neighbour slots [0, k) that SamplePaths() shuffles to
// pick which KNN neighbours are merged during temporal reuse.
//
// @param options Render options. They must really be a RISRenderOptions: the
//        reference dynamic_cast throws std::bad_cast otherwise.
void MaterialSpaceReSTIRPathSampler::SetParam(const RenderOptions& options) {
    PathSampler::SetParam(options);
    const RISRenderOptions& _options = dynamic_cast<const RISRenderOptions&>(options);
    this->M = _options.M;
    this->haveTemporalReuse = _options.haveTemporalReuse;
    this->k = _options.k;
    this->k1 = _options.k1;
    // SamplePaths() reads reservoirUsed[j] for every j < k1, and reservoirUsed
    // only holds k entries (one per neighbour returned by the KNN query).
    // Merging more neighbours than were found would index past the end, so
    // k1 is capped at k.
    if (this->k1 > this->k) this->k1 = this->k;
    this->historyLength = _options.historyLength;
    this->reservoirMergeNormalThreshold = _options.reservoirMergeNormalThreshold;
    this->reservoirMergeDistThreshold = _options.reservoirMergeDistThreshold;
    this->storeReservoir = _options.storeReservoir;

    // Identity permutation 0..k-1; SamplePaths() shuffles it and takes the
    // first k1 entries as the neighbours to merge.
    reservoirUsed.clear();
    for (int i = 0; i < this->k; ++i) reservoirUsed.push_back(i);
}

// Samples the primary hit and one direct-lighting (NEE) light vertex per
// camera ray with resampled importance sampling (RIS). It optionally merges
// reservoirs kept from earlier iterations ("temporal" reuse), which are found
// by a k-nearest-neighbour search over shading-point positions.
//
// Outputs written into `res`:
//   - res.mPathVertices[0]: the primary hit (pixel index, triangle, emitter
//     and BSDF ids, barycentrics) with pdf = 1.
//   - res.mNEEVertices[0]: the light sample selected by the reservoir, with
//     pdf = 1 / W (0 where W <= 0).
// Only slot 0 of each vector is filled here. The other slots are left as
// produced by resize().
//
// Member state read: M (light candidates per ray), haveTemporalReuse,
// k (neighbours returned per KNN query), k1 (neighbours actually merged),
// historyLength, reservoirMergeNormalThreshold, reservoirMergeDistThreshold,
// storeReservoir, reservoir (history) and reservoirOnFile.
// Although the method is const, it shuffles reservoirUsed (with
// randomGenerator), clamps reservoir->m in place, updates reuseMask through
// PrefilerReservoir and replaces reservoirOnFile.
// NOTE(review): presumably those members are `mutable` or pointers; confirm
// in the header.
void MaterialSpaceReSTIRPathSampler::SamplePaths(const Scene& scene, PathSampleResult& res) const {
    // Reset the outputs. Only index 0 of each is written by this sampler.
    res.mPathVertices.clear();
    res.mPathVertices.resize(mMaxBounces + 1);
    res.mNEEVertices.clear();
    res.mNEEVertices.resize(mMaxBounces);

    const Camera& camera = *scene.mSensors[0];
    Ray rays;
    // Generate (antithetic) rays
    camera.GenerateAntitheticRays(rays, mSppBatch, mAntitheticSpp);

    Intersection its;
    // Handle primary rays
    scene.IntersectHit(rays, its);

    if (rays.mNumRays > 0)
    {
        // Record the primary hit as path vertex 0. It is a deterministic
        // camera hit, so its pdf is 1.
        scene.PostIntersect(its);
        MaterialVertices& matV = res.mPathVertices[0];
        matV.numOfSamples = rays.mNumRays;
        matV.prevId = rays.mPixelIdx;
        matV.triangleId = its.mTriangleId;
        matV.emitterId = its.mEmitterId;
        matV.bsdfId = its.mBsdfId;
        matV.u = its.mBaryU;
        matV.v = its.mBaryV;
        matV.pdf = Ones(matV.u.GetShape());
    }

    // RIS: draw M light candidates per ray and stream them through a
    // per-ray reservoir whose shading points are the primary hits.
    // nRaysInit is the number of base (non-antithetic) rays. Random numbers
    // are drawn once per base ray and read through mRayIdx % nRaysInit.
    // Presumably this makes antithetic copies of a ray share the same random
    // numbers; confirm against GenerateAntitheticRays.
    int nRaysInit = (mSppBatch / mAntitheticSpp) * camera.mResX * camera.mResY;
    Reservoir curIterationReservoir(rays.mNumRays);
    // NOTE(review): this conversion also runs when rays.mNumRays == 0, in
    // which case res.mPathVertices[0] was never populated above.
    ConvertToSpatialVertices(scene, res.mPathVertices[0], *(curIterationReservoir.shadingPoints));
    for (int i = 0; i < M; ++i) {
        if (rays.mNumRays == 0) break;
        LightSample curLightSample;
        Tensorf antitheticRndLight = Tensorf::RandomFloat(Shape({ nRaysInit }, VecType::Vec2));
        Tensorf rndLight = IndexedRead(antitheticRndLight, rays.mRayIdx % Scalar(nRaysInit), 0);

        SampleEmitterDirect(scene, rays, its, rndLight, *curLightSample.vertices);

        SpatialVertices curSpatialLightSamples;
        ConvertToSpatialVertices(scene, *curLightSample.vertices, curSpatialLightSamples);

        // Target function p-hat of this candidate at the current shading point.
        curLightSample.pHat = EvalPHat(scene, rays, 
                *(curIterationReservoir.shadingPoints), curSpatialLightSamples, true);
        Tensorf rndReservoirSample = Tensorf::RandomFloat(
            Shape({ nRaysInit }, VecType::Scalar1));
        Expr rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
        // RIS weight = pHat / source pdf. The Scalar(1) is presumably the
        // per-candidate sample count added to m; confirm in UpdateReservoir.
        UpdateReservoir(curIterationReservoir, curLightSample, 
            curLightSample.pHat / curLightSample.vertices->pdf, Scalar(1), rndReservoir);
    }
    // Unbiased contribution weight of the selected sample: W = wAverage / pHat(y),
    // or 0 where pHat(y) == 0. Detached so no gradient flows through W.
    curIterationReservoir.W = Detach(Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
                                    curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
                                    Scalar(0.0f)));

    // Temporal reuse: merge up to k1 history reservoirs whose shading points
    // are near the current ones.
    if (haveTemporalReuse && reservoir->numOfSamples > 0) {
        // Snapshot of the current-iteration sample count, used for the MIS
        // weight below. NOTE(review): w0 is currently unused.
        Expr m0 = curIterationReservoir.m;
        Expr w0 = curIterationReservoir.wAverage;
        // Per ray, records which reservoir supplied the final sample:
        // 0 = the current iteration, j + 1 = the j-th merged neighbour.
        Tensori pickedIndex = Zeros(Shape({curIterationReservoir.numOfSamples}, VecType::Scalar1));
        // Cap the history length so old reservoirs cannot dominate the merge.
        reservoir->m = Where(reservoir->m > Scalar(M * historyLength), Scalar(M * historyLength), reservoir->m);

        SpatialVertices prevSpatialPoints;
        ConvertToSpatialVertices(scene, *reservoir->shadingPoints, prevSpatialPoints);
        // Prefilter the history: keep only entries whose shading point is in
        // the image plane and unoccluded from the current camera.
        std::shared_ptr<Reservoir> prefiltReservoir = PrefilerReservoir(scene, 
            camera, prevSpatialPoints.position);

        // k neighbour indices per current ray, indexing into the compacted
        // set of valid prefiltered entries.
        Tensori localShiftMapping = Zeros(Shape({ rays.mNumRays * k }, VecType::Scalar1));
        IndexMask validReservoirIndex = (prefiltReservoir->m > Scalar(0) && prefiltReservoir->wAverage > Scalar(0.0));
        {
            // Find, for each current shading point, its k nearest valid
            // history shading points.
            // NOTE(review): if validReservoirIndex.sum == 0 there are no
            // points to search; confirm what FindKNN writes in that case.
            Tensorf curPositions = curIterationReservoir.shadingPoints->position;
            ConvertToSpatialVertices(scene, 
                prefiltReservoir->shadingPoints->GetMaskedCopy(validReservoirIndex), 
                prevSpatialPoints);
            cukd::FindKNN(prevSpatialPoints.position.Data(), 
                prevSpatialPoints.numOfSamples, 
                curPositions.Data(), 
                rays.mNumRays, 
                localShiftMapping.Data(), k);
            localShiftMapping.CopyToHost();
        }

        // Pick a random subset of k1 of the k neighbour slots. The same order
        // is reused by the MIS loop below, so it must not be reshuffled in
        // between. Requires k1 <= k (reservoirUsed has k entries).
        shuffle(reservoirUsed.begin(), reservoirUsed.end(), randomGenerator);
        for (int j = 0; j < k1; ++j) {
            int i = reservoirUsed[j];
            // Slice [N*i, N*(i+1)) of localShiftMapping is read as "the i-th
            // neighbour of every ray". This assumes FindKNN writes neighbour-
            // major output; TODO confirm against cukd::FindKNN.
            Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                Tensori::ArrayRange(rays.mNumRays * i, rays.mNumRays * (i + 1), 1, false), 0);
            // Map compacted (valid-only) indices back to prefiltReservoir indices.
            Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, curLocalShiftMapping, 0);
            Reservoir previousReservoirCopy = prefiltReservoir->GetIndexedCopy(globalShiftMapping, rays.mNumRays);

            SpatialVertices prevSpatialLightSamples;
            ConvertToSpatialVertices(scene, *previousReservoirCopy.x->vertices, prevSpatialLightSamples);

            SpatialVertices previousShadingPoint;
            ConvertToSpatialVertices(scene, *(previousReservoirCopy.shadingPoints), previousShadingPoint);

            // Re-evaluate the neighbour's sample at the *current* shading point.
            previousReservoirCopy.x->pHat = EvalPHat(scene, rays, *(curIterationReservoir.shadingPoints), prevSpatialLightSamples, true);

            // Apply normal and distance test: neighbours whose shading point
            // differs too much get m = 0 and W = 0.
            Tensorf normalDotProduct = VectorDot(previousShadingPoint.normal, curIterationReservoir.shadingPoints->normal);
            Tensorf dist = VectorLength(previousShadingPoint.position - curIterationReservoir.shadingPoints->position);
            Tensorf nFilter = ((normalDotProduct > Scalar(reservoirMergeNormalThreshold)) * 
                            (dist < Scalar(reservoirMergeDistThreshold)));
            // Expr nFilter = ((dist < Scalar(reservoirMergeDistThreshold)));
            previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
            previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));

            Tensorf rndReservoirSample = Tensorf::RandomFloat(
                Shape({ nRaysInit }, VecType::Scalar1));

            Tensorf rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
            // updateMask > 0 marks rays whose sample was replaced by this neighbour's.
            Tensori updateMask = UpdateReservoir(curIterationReservoir, previousReservoirCopy, scene, rays, rndReservoir);
            pickedIndex = Where(updateMask > Scalar(0), Scalar(j + 1), pickedIndex);
        }

        // MIS weight for the finally selected sample y:
        //   pSum = sum over merged reservoirs of m_r * pHat_r(y)
        //   pI   = pHat of the reservoir that supplied y
        //   mi   = pI / pSum
        // This appears to follow the MIS-weighted unbiased combination of
        // Bitterli et al. 2020 (ReSTIR); confirm before changing.
        // Tensori Z = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        // Tensori _M = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        Tensorf pI = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        Tensorf pSum = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        SpatialVertices curLightSampleSpatialPoints;
        ConvertToSpatialVertices(scene, *(curIterationReservoir.x->vertices), curLightSampleSpatialPoints);
        for (int j = 0; j < k1; ++j) {
            // Rebuild exactly the same neighbour reservoirs as in the merge loop.
            int i = reservoirUsed[j];
            Tensori curLocalShiftMapping = IndexedRead(localShiftMapping, 
                Tensori::ArrayRange(rays.mNumRays * i, rays.mNumRays * (i + 1), 1, false), 0);
            Tensori globalShiftMapping = IndexedRead(validReservoirIndex.index, curLocalShiftMapping, 0);
            Reservoir previousReservoirCopy = prefiltReservoir->GetIndexedCopy(globalShiftMapping, rays.mNumRays);

            SpatialVertices previousShadingPoint;
            ConvertToSpatialVertices(scene, *(previousReservoirCopy.shadingPoints), previousShadingPoint);

            // Apply normal test (must match the merge loop above)
            Expr normalDotProduct = VectorDot(previousShadingPoint.normal, curIterationReservoir.shadingPoints->normal);
            Tensorf dist = VectorLength(previousShadingPoint.position - curIterationReservoir.shadingPoints->position);
            Expr nFilter = ((normalDotProduct > Scalar(reservoirMergeNormalThreshold)) * 
                                (dist < Scalar(reservoirMergeDistThreshold)));
            // Expr nFilter = ((dist < Scalar(reservoirMergeDistThreshold)));
            previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
            previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));

            // pHat of the final sample y evaluated at the neighbour's shading point.
            Expr pHatPrevCur = EvalPHat(scene, rays, previousShadingPoint, curLightSampleSpatialPoints, true);
            // m is already zeroed where nFilter fails, so the extra nFilter
            // factor is redundant (but harmless).
            pSum = pSum + previousReservoirCopy.m * pHatPrevCur * nFilter;
            pI = Where(pickedIndex == Scalar(j + 1), pHatPrevCur, pI);
            // Z = Z + Where(pHatPrevCur > Scalar(0.0f), previousReservoirCopy.m, Scalar(0));
            // _M = _M + previousReservoirCopy.m;
        }

        // Test itself: the current-iteration reservoir contributes m0 * pHat(y).
        Expr pHatCurCur = EvalPHat(scene, rays, *(curIterationReservoir.shadingPoints), curLightSampleSpatialPoints, true);
        pSum = pSum + m0 * pHatCurCur;
        pI = Where(pickedIndex == Scalar(0), pHatCurCur, pI);
        // Z = Z + Where(pHatCurCur > Scalar(0.0f), m0, Scalar(0));
        // _M = _M + m0;

        Tensorf mi = Where(pSum > Scalar(0.0f), pI / pSum, Scalar(0.0f));
        // Recover the running weight sum (assumes wAverage == wSum / m; confirm
        // in Reservoir), then W = wSum * mi / pHat(y) at the current shading point.
        Tensorf wSum = curIterationReservoir.wAverage * curIterationReservoir.m;
        curIterationReservoir.W = Where(pHatCurCur > Scalar(0.0f), Scalar(1.0f) / pHatCurCur * wSum * mi, Scalar(0.0f));

        // curIterationReservoir.W = Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
        //                                 curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
        //                                 Scalar(0.0f));
        // curIterationReservoir.W = Where(Z > Scalar(0.0f), curIterationReservoir.W * Tensorf(_M / Z), Scalar(0.0f));
    }

    // NOTE(review): the history member `reservoir` is not updated in this
    // function (the assignment below is commented out); presumably it is
    // updated elsewhere. Confirm.
    // reservoir = curIterationReservoir;
    if (haveTemporalReuse && storeReservoir) {
        // Accumulate this iteration's reservoirs into reservoirOnFile.
        // NOTE(review): reservoirOnFile is dereferenced without a null check.
        Reservoir r;
        Reservoir::Combine(*reservoirOnFile, curIterationReservoir, r);
        reservoirOnFile = std::make_shared<Reservoir>(r);
    }

    // Emit the selected light sample as NEE vertex 0 with pdf = 1 / W.
    res.mNEEVertices[0] = *curIterationReservoir.x->vertices;
    res.mNEEVertices[0].pdf = Where(curIterationReservoir.W > Scalar(0.0f), Scalar(1.0f) / curIterationReservoir.W,
                                    Scalar(0.0f));
}

void EvalMaterialSpaceReSTIR(const Scene& scene, 
                           const SpatialVertices& vCamera, 
                           const SpatialVertices& vCur, 
                           const SpatialVertices& vLight, 
                           Tensorf& res) {
    Expr li = Zeros(Shape({ vCur.numOfSamples }, VecType::Vec3));

    Expr wi = vLight.position - vCur.position;
    Expr its_light_dist_squared = VectorSquaredLength(wi);
    Expr its_light_dist = Sqrt(its_light_dist_squared);
    wi = wi / its_light_dist;

    Expr dotShNorm = VectorDot(vCur.normal, wi);
    Expr dotGeoNorm = VectorDot(vCur.geoNormal, wi);

    Expr connectValid = Scalar(1.0f);
    auto numStable = (its_light_dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
    auto noLightLeak = ((dotShNorm > Scalar(0.0f)) && (dotGeoNorm > Scalar(0.0f)));
    Expr visible;
    {
        Ray shadowRays;
        shadowRays.mNumRays = vCur.numOfSamples;
        shadowRays.mOrg = vCur.position;
        shadowRays.mDir = wi;
        shadowRays.mMin = Ones(vCur.numOfSamples) * Scalar(SHADOW_EPSILON);
        shadowRays.mMax = its_light_dist - Scalar(SHADOW_EPSILON);
        scene.Occluded(shadowRays, visible);
    }
    connectValid = numStable * noLightLeak * visible;

    Expr wo = VectorNormalize(vCamera.position - vCur.position);
    Expr nDotWi = Abs(VectorDot(vLight.normal, -wi));
    Expr G = nDotWi / its_light_dist_squared;
    Expr pdf = Detach(vLight.pdf);
    Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(vLight, -wi);

    // Sample point on area light
    for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
    {
        IndexMask mask_bsdf = vCur.bsdfId == Scalar(iBSDF);
        Expr condBsdf = (vCur.bsdfId == Scalar(iBSDF));
        if (mask_bsdf.sum == 0) continue;

        SpatialVertices curV = vCur.GetMaskedCopy(mask_bsdf);
        Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(curV, 
            Mask(wo, mask_bsdf, 0), Mask(wi, mask_bsdf, 0)) * Abs(Mask(dotShNorm, mask_bsdf, 0));

        Expr curG = Mask(G, mask_bsdf, 0);
        Expr curLe = Mask(Le, mask_bsdf, 0);
        Expr curConnectValid = Mask(connectValid, mask_bsdf, 0);
        Expr curJ = Mask(vLight.J, mask_bsdf, 0);

        li = li + IndexedWrite(bsdfVal * curG * curLe * curJ * curConnectValid, mask_bsdf.index, li->GetShape(), 0);
    }

    // r.W = Tensorf(r.W);
    li = Tensorf(li);

    res = NaNToZero(li * Where(pdf > Scalar(0.0f), Scalar(1.0f) / pdf, Scalar(0.0f)));
}

std::shared_ptr<Reservoir> MaterialSpaceReSTIRPathSampler::PrefilerReservoir(const Scene& scene, const Camera& camera,
                                                                             const Tensorf& curPosition) const {
    // Check if the shading point is in image plane
    SensorDirectSample sds = camera.sampleDirect(curPosition);
    IndexMask isInImagePlaneMask = IndexMask(sds.isValid);
    // Check visibility
    Ray primaryRay;
    camera.GenerateBoundaryRays(sds, primaryRay);
    Tensorf pos = IndexedRead(curPosition, isInImagePlaneMask.index, 0);
    auto pToCamDist = VectorLength(pos - camera.mPosTensor);
    primaryRay.mMin = Scalar(SHADOW_EPSILON);
    primaryRay.mMax = pToCamDist - Scalar(SHADOW_EPSILON);
    Tensorb isVisible;
    scene.Occluded(primaryRay, isVisible);
    isVisible = IndexedWrite(isVisible, isInImagePlaneMask.index, 
            isInImagePlaneMask.mask.GetShape(), 0);

    IndexMask validReservoirIndex = IndexMask(isInImagePlaneMask.mask * isVisible);
    reuseMask = reuseMask * (Scalar(1) - validReservoirIndex.mask);
    return std::make_shared<Reservoir>(reservoir->GetIndexedCopy(validReservoirIndex.index, validReservoirIndex.sum));
}

} // namespace TensorRay
} // namespace EDX
