// Copyright @yucwang 2022

#include "ScreenSpaceReSTIR.h"

namespace EDX {
namespace TensorRay {

void ScreenSpaceReSTIRPathSampler::SetParam(const RenderOptions& options) {
    PathSampler::SetParam(options);
    const RISRenderOptions& _options = dynamic_cast<const RISRenderOptions&>(options);
    this->M = _options.M;
    this->haveTemporalReuse = _options.haveTemporalReuse;
}

// Samples primary hits and one resampled light vertex per camera ray.
//
// Steps performed, in order:
//   1. Generate camera rays from scene.mSensors[0] and intersect them with the
//      scene; the hits are stored in res.mPathVertices[0].
//   2. Draw M light-sample candidates per ray and stream them into a fresh
//      per-call reservoir, weighting each candidate by pHat / pdf.
//   3. If haveTemporalReuse is set and the stored member `reservoir` already
//      holds samples, merge the stored reservoir into the fresh one and
//      renormalize W with the Z term computed below.
//   4. Write the fresh reservoir back into `reservoir`, remember `scene` in
//      `prevReferenceSpace`, and publish the selected light sample in
//      res.mNEEVertices[0] with pdf = 1 / W (0 where W <= 0).
//
// Only index 0 of res.mPathVertices / res.mNEEVertices is filled here; the
// remaining entries are left as created by resize().
//
// NOTE(review): this method is declared const but assigns to `reservoir` and
// `prevReferenceSpace`; presumably both are declared `mutable` in the header —
// confirm.
void ScreenSpaceReSTIRPathSampler::SamplePaths(const Scene& scene, PathSampleResult& res) const {
    // Reset the output containers: mMaxBounces + 1 path-vertex slots and
    // mMaxBounces NEE-vertex slots.
    res.mPathVertices.clear();
    res.mPathVertices.resize(mMaxBounces + 1);
    res.mNEEVertices.clear();
    res.mNEEVertices.resize(mMaxBounces);

    // Only the first sensor of the scene is used.
    const Camera& camera = *scene.mSensors[0];
    Ray rays;
    // Generate (antithetic) rays
    camera.GenerateAntitheticRays(rays, mSppBatch, mAntitheticSpp);

    Intersection its;
    // Handle primary rays
    scene.IntersectHit(rays, its);

    if (rays.mNumRays > 0)
    {
        // Copy the primary-hit data into path vertex 0. prevId holds the pixel
        // index of each ray and the pdf is set to all ones.
        scene.PostIntersect(its);
        MaterialVertices& matV = res.mPathVertices[0];
        matV.numOfSamples = rays.mNumRays;
        matV.prevId = rays.mPixelIdx;
        matV.triangleId = its.mTriangleId;
        matV.emitterId = its.mEmitterId;
        matV.bsdfId = its.mBsdfId;
        matV.u = its.mBaryU;
        matV.v = its.mBaryV;
        matV.pdf = Ones(matV.u.GetShape());
    }
    // NOTE(review): everything below except the candidate loop also runs when
    // rays.mNumRays == 0 (res.mPathVertices[0] is then left default-constructed).
    // Confirm that Reservoir(0), ConvertToSpatialVertices and the Where()
    // expressions below tolerate an empty batch.

    // Handle secondary rays
    // Size of the random-number tensors drawn below. Rays are mapped onto them
    // via rays.mRayIdx % nRaysInit, so rays whose indices are equal modulo
    // nRaysInit read the same random numbers (presumably the antithetic
    // partners of one base sample — confirm against GenerateAntitheticRays).
    int nRaysInit = (mSppBatch / mAntitheticSpp) * camera.mResX * camera.mResY;
    // Fresh reservoir for this call, one slot per ray; its shading points are
    // the spatial form of the primary hits.
    Reservoir curIterationReservoir(rays.mNumRays);
    ConvertToSpatialVertices(scene, res.mPathVertices[0], *curIterationReservoir.shadingPoints);
    // Stream M light-sample candidates into the reservoir.
    for (int i = 0; i < M; ++i) {
        if (rays.mNumRays == 0) break;
        LightSample curLightSample;
        // Vec2 random numbers used to pick the light sample for each ray.
        Tensorf antitheticRndLight = Tensorf::RandomFloat(Shape({ nRaysInit }, VecType::Vec2));
        Tensorf rndLight = IndexedRead(antitheticRndLight, rays.mRayIdx % Scalar(nRaysInit), 0);

        SampleEmitterDirect(scene, rays, its, rndLight, *curLightSample.vertices);
        SpatialVertices curSpatialLightSamples;
        ConvertToSpatialVertices(scene, *curLightSample.vertices, curSpatialLightSamples);

        // Not check shadow ray for now
        // NOTE(review): the meaning of the trailing boolean passed to EvalPHat
        // is not visible in this file; check it against the comment above.
        curLightSample.pHat = EvalPHat(scene, rays, *curIterationReservoir.shadingPoints, curSpatialLightSamples, true);
        // Scalar random numbers consumed by the reservoir update.
        Tensorf rndReservoirSample = Tensorf::RandomFloat(
            Shape({ nRaysInit }, VecType::Scalar1));
        Expr rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
        // Candidate weight is pHat / source pdf; Scalar(1) is passed as the
        // fourth argument (presumably the candidate's sample count — confirm
        // against the UpdateReservoir declaration).
        UpdateReservoir(curIterationReservoir, curLightSample, 
            curLightSample.pHat / curLightSample.vertices->pdf, Scalar(1), rndReservoir);
    }
    // curIterationReservoir.x.pHat = EvalPHat(scene, rays, its, curIterationReservoir.x, true);
    // W = wAverage / pHat of the selected sample, 0 where pHat <= 0. The value
    // is wrapped in Detach.
    curIterationReservoir.W = Detach(Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
                                    curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
                                    Scalar(0.0f)));

    // Temporal reuse: only when enabled and the stored reservoir holds samples
    // (i.e. it has been filled by an earlier call).
    if (haveTemporalReuse && reservoir.numOfSamples > 0) {

        // Sample count of the current reservoir, captured before the merge.
        Expr m0 = curIterationReservoir.m;
        // Per-ray copy of the stored reservoir, looked up by pixel index.
        Reservoir previousReservoirCopy = reservoir.GetIndexedCopy(rays.mPixelIdx, rays.mNumRays);
        // Clamp the history length to at most 4 * M.
        previousReservoirCopy.m = Where(previousReservoirCopy.m > Scalar(M * 4), Scalar(M * 4), previousReservoirCopy.m);
        // Re-evaluate the stored sample's pHat at the current shading points.
        SpatialVertices prevSpatialLightSamples;
        ConvertToSpatialVertices(scene, *previousReservoirCopy.x->vertices, prevSpatialLightSamples);
        previousReservoirCopy.x->pHat = EvalPHat(scene, rays, *curIterationReservoir.shadingPoints, prevSpatialLightSamples, true);

        // Apply normal test
        // Where the stored and current shading normals have a dot product of
        // 0.5 or less, the stored reservoir's m and W are zeroed.
        Expr normalDotProduct = VectorDot(previousReservoirCopy.shadingPoints->normal, curIterationReservoir.shadingPoints->normal);
        Expr nFilter = (normalDotProduct > Scalar(0.5f));
        previousReservoirCopy.m = Where(nFilter, previousReservoirCopy.m, Scalar(0));
        previousReservoirCopy.W = Where(nFilter, previousReservoirCopy.W, Scalar(0.0f));

        // Merge the stored reservoir into the current one.
        Tensorf rndReservoirSample = Tensorf::RandomFloat(
            Shape({ nRaysInit }, VecType::Scalar1));
        Tensorf rndReservoir = IndexedRead(rndReservoirSample, rays.mRayIdx % Scalar(nRaysInit), 0);
        UpdateReservoir(curIterationReservoir, previousReservoirCopy, scene, rays, rndReservoir);

        // Z accumulates the m of every source reservoir at whose shading point
        // the selected sample has pHat > 0. First the stored reservoir,
        // evaluated against the stored scene prevReferenceSpace ...
        Tensori Z = Zeros(Shape({ rays.mNumRays }, VecType::Scalar1));
        SpatialVertices curLightSampleSpatialPoints;
        ConvertToSpatialVertices(scene, *curIterationReservoir.x->vertices, curLightSampleSpatialPoints);
        SpatialVertices previousShadingPoint;
        ConvertToSpatialVertices(prevReferenceSpace, *previousReservoirCopy.shadingPoints, previousShadingPoint);
        Expr pHatPrevCur = EvalPHat(prevReferenceSpace, rays, previousShadingPoint, curLightSampleSpatialPoints, false);
        Z = Z + Where(pHatPrevCur > Scalar(0.0f), previousReservoirCopy.m, Scalar(0));

        // Test itself
        // ... then the current reservoir, using its pre-merge count m0.
        Expr pHatCurCur = EvalPHat(scene, rays, *curIterationReservoir.shadingPoints, curLightSampleSpatialPoints, false);
        Z = Z + Where(pHatCurCur > Scalar(0.0f), m0, Scalar(0));
        // Recompute W = wAverage / pHat after the merge, then rescale it by
        // m / Z (m read after the merge); W is 0 where Z is not positive.
        // NOTE(review): unlike the W computed before this block, this one is
        // not wrapped in Detach — confirm that is intended.
        // NOTE(review): if _M and Z are both integer-valued, confirm `_M / Z`
        // does not truncate before the conversion to Tensorf.
        Expr _M = curIterationReservoir.m;
        curIterationReservoir.W = Where(curIterationReservoir.x->pHat > Scalar(0.0f), 
                                        curIterationReservoir.wAverage / curIterationReservoir.x->pHat, 
                                        Scalar(0.0f));
        curIterationReservoir.W = Where(Z > Scalar(0), curIterationReservoir.W * Tensorf(_M / Z), Scalar(0.0f));
    }
    // Create the stored reservoir on first use, then write this call's result
    // into it at the rays' pixel indices.
    if (reservoir.numOfSamples == 0)
        reservoir = Reservoir(nRaysInit);
    reservoir.UpdateIndexed(curIterationReservoir, rays.mPixelIdx);
    // Keep the scene so that the next call can evaluate against it.
    prevReferenceSpace = scene;

    // Publish the selected light sample; its pdf is 1 / W, or 0 where W <= 0.
    res.mNEEVertices[0] = *curIterationReservoir.x->vertices;
    res.mNEEVertices[0].pdf = Where(curIterationReservoir.W > Scalar(0.0f), Scalar(1.0f) / curIterationReservoir.W,
                                    Scalar(0.0f));
}

void EvalScreenSpaceReSTIR(const Scene& scene, 
                           const SpatialVertices& vCamera, 
                           const SpatialVertices& vCur, 
                           const SpatialVertices& vLight, 
                           Tensorf& res) {
    Expr li = Zeros(Shape({ vCur.numOfSamples }, VecType::Vec3));

    Expr wi = vLight.position - vCur.position;
    Expr its_light_dist_squared = VectorSquaredLength(wi);
    Expr its_light_dist = Sqrt(its_light_dist_squared);
    wi = wi / its_light_dist;

    Expr dotShNorm = VectorDot(vCur.normal, wi);
    Expr dotGeoNorm = VectorDot(vCur.geoNormal, wi);

    Expr connectValid = Scalar(1.0f);
    auto numStable = (its_light_dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
    auto noLightLeak = ((dotShNorm > Scalar(0.0f)) && (dotGeoNorm > Scalar(0.0f)));
    Expr visible;
    {
        Ray shadowRays;
        shadowRays.mNumRays = vCur.numOfSamples;
        shadowRays.mOrg = vCur.position;
        shadowRays.mDir = wi;
        shadowRays.mMin = Ones(vCur.numOfSamples) * Scalar(SHADOW_EPSILON);
        shadowRays.mMax = its_light_dist - Scalar(SHADOW_EPSILON);
        scene.Occluded(shadowRays, visible);
    }
    connectValid = numStable * noLightLeak * visible;

    Expr wo = VectorNormalize(vCamera.position - vCur.position);
    Expr nDotWi = VectorDot(vLight.normal, -wi);
    Expr G = Abs(nDotWi) / its_light_dist_squared;
    Expr pdf = Detach(vLight.pdf);
    Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(vLight, -wi);

    // Sample point on area light
    for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
    {
        IndexMask mask_bsdf = vCur.bsdfId == Scalar(iBSDF);
        Expr condBsdf = (vCur.bsdfId == Scalar(iBSDF));
        if (mask_bsdf.sum == 0) continue;

        SpatialVertices curV = vCur.GetMaskedCopy(mask_bsdf);
        Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(curV, 
            Mask(wo, mask_bsdf, 0), Mask(wi, mask_bsdf, 0)) * Abs(Mask(dotShNorm, mask_bsdf, 0));

        Expr curG = Mask(G, mask_bsdf, 0);
        Expr curLe = Mask(Le, mask_bsdf, 0);
        Expr curConnectValid = Mask(connectValid, mask_bsdf, 0);
        Expr curJ = Mask(vLight.J, mask_bsdf, 0);

        li = li + IndexedWrite(bsdfVal * curG * curLe * curJ * curConnectValid, mask_bsdf.index, li->GetShape(), 0);
    }

    // r.W = Tensorf(r.W);
    li = Tensorf(li);

    res = NaNToZero(li * Where(pdf > Scalar(0.0f), Scalar(1.0f) / pdf, Scalar(0.0f)));
}

} // namespace TensorRay
} // namespace EDX