// Copyright @yucwang 2022

#include "ReSTIRDataTypes.h"

#include "../Algorithm1.h"

namespace EDX {
namespace TensorRay {

// Streams a candidate light sample `xI` into reservoir `r`.
//
// r      : reservoir updated in place (r.m, r.wAverage and, on selected lanes, r.x).
// xI     : candidate light sample.
// wI     : per-lane resampling weight of the candidate.
// m1     : number of candidates represented by xI; lanes with m1 <= 0 are never
//          selected.
// sample : per-lane random numbers compared against the selection probability
//          (presumably uniform in [0, 1) -- TODO confirm with callers).
// Returns the per-lane mask of lanes whose stored sample was replaced by xI.
Tensori UpdateReservoir(Reservoir& r, const LightSample& xI, const Expr& wI, const Expr& m1,
                     const Expr& sample) {
    // Grow the candidate count first; the normalisation below uses the new total.
    r.m = r.m + m1;
    Expr nonEmpty = (r.m > Scalar(0));
    Expr invCount = Where(nonEmpty, Scalar(1.0f) / r.m, Scalar(0.0f));

    // Keep wAverage as the mean weight over all r.m candidates: the old mean is
    // rescaled by (m_old / m_new) and the new candidates add m1 * wI / m_new.
    Expr keptPart = (r.m - m1) * invCount * r.wAverage;
    Expr addedPart = m1 * invCount * wI;
    r.wAverage = keptPart + addedPart;

    // Probability of choosing xI = m1 * wI / (r.m * wAverage), i.e. its share of
    // the accumulated weight. Lanes with wAverage <= 0 get a zero reciprocal.
    Expr invWAverage = Where(r.wAverage > Scalar(0.0f), Scalar(1.0f) / r.wAverage, Scalar(0.0f));
    Expr acceptThreshold = wI * m1 * invCount * invWAverage;
    Expr replaceMask = (nonEmpty * (sample < acceptThreshold) * (m1 > Scalar(0)));

    // NOTE: lanes where every candidate has p_hat == 0 are not force-updated; an
    // earlier "forceUpdate" workaround for that case is disabled.
    r.x->UpdateMasked(xI, replaceMask);

    return replaceMask;
}

// Merges reservoir `r1` into `r`, recomputes r.W and clamps r.m to `mMax`.
//
// r1 is merged only on lanes where
//   * the dot product of the two shading normals exceeds 0.6 (a cosine
//     threshold if the normals are unit length -- assumed, not checked here), and
//   * r1's stored sample has a positive p_hat.
// On other lanes r1 contributes m = 0 and W = 0, so the streaming update ignores it.
//
// r      : destination reservoir, updated in place.
// r1     : source reservoir (e.g. temporal or spatial neighbour), read only.
// sample : per-lane random numbers forwarded to the streaming update.
// mMax   : upper bound applied to r.m after merging.
void UpdateReservoir(Reservoir& r, const Reservoir& r1, const Expr& sample, int mMax) {
    Expr G = (VectorDot(r.shadingPoints->normal, r1.shadingPoints->normal));
    Expr r1Filter = ((G > Scalar(0.6f)) * (r1.x->pHat > Scalar(0.0f)));

    Expr m1 = Where(r1Filter, r1.m, Scalar(0));
    Expr W1 = Where(r1Filter, r1.W, Scalar(0.0f));

    // The merge weight uses r1.x->pHat as stored in r1; p_hat is not re-evaluated
    // at r's shading point here.
    UpdateReservoir(r, *r1.x, W1 * r1.x->pHat, m1, sample);

    // Contribution weight W = wAverage / p_hat(selected sample), 0 where p_hat <= 0.
    r.W = Detach(Where(r.x->pHat > Scalar(0.0f), r.wAverage / r.x->pHat, Scalar(0.0f)));

    // Clamp the candidate count (presumably to bound how much history is carried).
    r.m = Where(r.m > Scalar(mMax), Scalar(mMax), r.m);
}

// Merges reservoir `r1` into `r` and recomputes r.W, without clamping r.m.
//
// Lanes are merged only where the dot product of the shading normals exceeds
// 0.6 (unit normals assumed) and r1's stored sample has positive p_hat. Other
// lanes contribute m = 0 and W = 0.
void UpdateReservoir2(Reservoir& r, const Reservoir& r1, const Expr& sample) {
    Expr normalDot = (VectorDot(r.shadingPoints->normal, r1.shadingPoints->normal));
    Expr keepNeighbour = ((normalDot > Scalar(0.6f)) * (r1.x->pHat > Scalar(0.0f)));

    Expr neighbourM = Where(keepNeighbour, r1.m, Scalar(0));
    Expr neighbourW = Where(keepNeighbour, r1.W, Scalar(0.0f));
    Expr neighbourWeight = neighbourW * r1.x->pHat;

    UpdateReservoir(r, *r1.x, neighbourWeight, neighbourM, sample);

    // W = wAverage / p_hat(selected sample), or 0 where p_hat is not positive.
    Expr positivePHat = r.x->pHat > Scalar(0.0f);
    r.W = Detach(Where(positivePHat, r.wAverage / r.x->pHat, Scalar(0.0f)));
}

// Merges reservoir `r1` into `r` using r1's stored W, sample and m, with no
// similarity filtering and no recomputation of r.W.
//
// `scene` and `rays` are currently unused. A disabled draft of this function
// (removed from the body, where it sat unreachable after the return) would have:
//   * re-evaluated p_hat of r's selected sample at r's shading point and at
//     r1's shading point via EvalPHat,
//   * summed into Z the m of every reservoir whose p_hat there is positive, and
//   * rescaled r.W by r.m / Z
// (presumably the unbiased 1/Z combination from ReSTIR -- TODO if re-enabled).
//
// Returns the selection mask produced by the streaming UpdateReservoir call.
Tensori UpdateReservoir(Reservoir& r, const Reservoir& r1, const Scene& scene,
                     const Ray& rays, const Expr& sample) {
    const Expr mergeWeight = r1.W * r1.x->pHat;
    return UpdateReservoir(r, *r1.x, mergeWeight, r1.m, sample);
}

// Evaluates the ReSTIR target function p_hat, one value per light sample, as
// seen from the surface points in `its`.
//
// Per lane:
//   p_hat = Luminance(f_bsdf(wo, wi) * |n_s . wi| * G * Le * J * V)
//   G     = |n_light . -wi| / dist^2,  wo = -rays.mDir
// V is the connection-validity term. After that, RESTIR_PHAT_EPSILON is added
// on every lane where V > 0 (presumably so valid connections never end up with
// p_hat == 0).
//
// scene         : supplies the area light (mLights[mAreaLightIndex]) and the BSDFs.
// rays          : incoming rays; -mDir is wo, and mNumRays sizes the shadow-ray batch
//                 (assumes it matches lightSamples.numOfSamples -- TODO confirm).
// its           : shading points.
// lightSamples  : light vertices, converted to SpatialVertices internally.
// testShadowRay : true  -> V = numeric-stability * same-side-of-surface * visibility;
//                 false -> V = 1.
// Returns p_hat wrapped in Detach().
Tensorf EvalPHat(const Scene& scene, const Ray& rays, const Intersection& its,
              const LightSample& lightSamples, bool testShadowRay) {
    Expr pHat = Zeros(Shape({ lightSamples.numOfSamples }, VecType::Scalar1));

    SpatialVertices spatialLightSamples;
    ConvertToSpatialVertices(scene, *lightSamples.vertices, spatialLightSamples);

    Expr wi = spatialLightSamples.position - its.mPosition;
    Expr its_light_dist_squared = VectorSquaredLength(wi);
    Expr its_light_dist = Sqrt(its_light_dist_squared);
    wi = wi / its_light_dist;

    Expr dotShNorm = VectorDot(its.mNormal, wi);
    Expr dotGeoNorm = VectorDot(its.mGeoNormal, wi);

    Expr connectValid = Scalar(1.0f);
    if (testShadowRay) {
        auto numStable = (its_light_dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
        // Light must be on the front side of both the shading and geometric normals.
        auto noLightLeak = ((dotShNorm > Scalar(0.0f)) * (dotGeoNorm > Scalar(0.0f)));
        Expr visible;
        {
            Ray shadowRays;
            shadowRays.mNumRays = rays.mNumRays;
            shadowRays.mOrg = its.mPosition;
            shadowRays.mDir = wi;
            shadowRays.mMin = Ones(rays.mNumRays) * Scalar(SHADOW_EPSILON);
            shadowRays.mMax = its_light_dist - Scalar(SHADOW_EPSILON);
            scene.Occluded(shadowRays, visible);
        }
        connectValid = numStable * noLightLeak * visible;
    }

    Tensorf wo = -rays.mDir;
    Expr nDotWi = VectorDot(spatialLightSamples.normal, -wi);
    Expr G = Abs(nDotWi) / its_light_dist_squared;
    Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(spatialLightSamples, -wi);

    // Evaluate each BSDF only on the lanes whose shading point uses it, then
    // scatter the per-material result back into the full-size pHat.
    for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
    {
        IndexMask mask_bsdf = its.mBsdfId == Scalar(iBSDF);
        if (mask_bsdf.sum == 0) continue;

        Intersection curIts = its.GetMaskedCopy(mask_bsdf, true);
        Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(curIts, 
            Mask(wo, mask_bsdf, 0), Mask(wi, mask_bsdf, 0)) * Abs(Mask(dotShNorm, mask_bsdf, 0));

        Expr curG = Mask(G, mask_bsdf, 0);
        Expr curLe = Mask(Le, mask_bsdf, 0);
        // Multiplying by mask_bsdf.mask presumably gives connectValid full lane
        // shape even when it is the scalar 1 (testShadowRay == false).
        Expr curConnectValid = Mask(connectValid * mask_bsdf.mask, mask_bsdf, 0);
        Expr curJ = Mask(spatialLightSamples.J, mask_bsdf, 0);

        pHat = pHat + IndexedWrite(Luminance(bsdfVal * curG * curLe * curJ * curConnectValid), 
                    mask_bsdf.index, pHat->GetShape(), 0);
    }
    pHat = pHat + Where(connectValid > Scalar(0), Scalar(RESTIR_PHAT_EPSILON), Scalar(0.0f));

    return Tensorf(Detach(pHat));
}

// Evaluates the ReSTIR target function p_hat for the light vertices in
// `lightSamples`, as seen from the shading vertices in `curV`.
//
// Per lane:
//   p_hat = Luminance(f_bsdf(wo, wi) * |n_s . wi| * G * Le * J * V)
//   G     = |n_light . -wi| / dist^2,  wo = -rays.mDir
// V is the connection-validity term. RESTIR_PHAT_EPSILON is then added where
// V * (n_light . -wi) * (n_s . wi) > 0, and the absolute value is returned.
//
// scene         : supplies the area light (mLights[mAreaLightIndex]) and the BSDFs.
// rays          : incoming rays; -mDir is wo, and mNumRays sizes the shadow-ray batch
//                 (assumes it matches lightSamples.numOfSamples -- TODO confirm).
// curV          : shading vertices.
// lightSamples  : light vertices.
// testShadowRay : true  -> V = numeric-stability * same-side-of-surface * visibility;
//                 false -> V = 1.
// Returns |p_hat| wrapped in Detach().
Tensorf EvalPHat(const Scene& scene, const Ray& rays, const SpatialVertices& curV,
              const SpatialVertices& lightSamples, bool testShadowRay) {
    Expr pHat = Zeros(Shape({ lightSamples.numOfSamples }, VecType::Scalar1));

    Expr wi = lightSamples.position - curV.position;
    Expr its_light_dist_squared = VectorSquaredLength(wi);
    Expr its_light_dist = Sqrt(its_light_dist_squared);
    wi = wi / its_light_dist;

    Expr dotShNorm = VectorDot(curV.normal, wi);
    Expr dotGeoNorm = VectorDot(curV.geoNormal, wi);

    Expr connectValid = Scalar(1.0f);
    if (testShadowRay) {
        auto numStable = (its_light_dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
        // Light must be on the front side of both the shading and geometric normals.
        auto noLightLeak = ((dotShNorm > Scalar(0.0f)) * (dotGeoNorm > Scalar(0.0f)));
        Expr visible;
        {
            Ray shadowRays;
            shadowRays.mNumRays = rays.mNumRays;
            shadowRays.mOrg = curV.position;
            shadowRays.mDir = wi;
            shadowRays.mMin = Ones(rays.mNumRays) * Scalar(SHADOW_EPSILON);
            shadowRays.mMax = its_light_dist - Scalar(SHADOW_EPSILON);
            scene.Occluded(shadowRays, visible);
        }
        connectValid = numStable * noLightLeak * visible;
    }

    Tensorf wo = -rays.mDir;
    Expr nDotWi = VectorDot(lightSamples.normal, -wi);
    Expr G = Abs(nDotWi) / its_light_dist_squared;
    Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(lightSamples, -wi);

    // Evaluate each BSDF only on the lanes whose shading vertex uses it, then
    // scatter the per-material result back into the full-size pHat.
    for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
    {
        IndexMask mask_bsdf = curV.bsdfId == Scalar(iBSDF);
        if (mask_bsdf.sum == 0) continue;

        SpatialVertices curIts = curV.GetMaskedCopy(mask_bsdf);
        Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(curIts, 
            Mask(wo, mask_bsdf, 0), Mask(wi, mask_bsdf, 0)) * Abs(Mask(dotShNorm, mask_bsdf, 0));

        Expr curG = Mask(G, mask_bsdf, 0);
        Expr curLe = Mask(Le, mask_bsdf, 0);
        // Multiplying by mask_bsdf.mask presumably gives connectValid full lane
        // shape even when it is the scalar 1 (testShadowRay == false).
        Expr curConnectValid = Mask(connectValid * mask_bsdf.mask, mask_bsdf, 0);
        Expr curJ = Mask(lightSamples.J, mask_bsdf, 0);

        pHat = pHat + IndexedWrite(Luminance(bsdfVal * curG * curLe * curJ * curConnectValid), 
                    mask_bsdf.index, pHat->GetShape(), 0);
    }
    // Epsilon only where the connection is valid and the light normal and the
    // shading normal give same-signed cosines.
    pHat = pHat + Where((connectValid * nDotWi * dotShNorm) > Scalar(0), Scalar(RESTIR_PHAT_EPSILON), Scalar(0.0f));

    return Tensorf(Detach(Abs(pHat)));
}

void EvalScreenSpaceReSTIR(const Scene& scene, 
                           const SpatialVertices& vCamera, 
                           const SpatialVertices& vCur, 
                           const SpatialVertices& vLight, 
                           Tensorf& res) {
    Expr li = Zeros(Shape({ vCur.numOfSamples }, VecType::Vec3));

    Expr wi = vLight.position - vCur.position;
    Expr its_light_dist_squared = VectorSquaredLength(wi);
    Expr its_light_dist = Sqrt(its_light_dist_squared);
    wi = wi / its_light_dist;

    Expr dotShNorm = VectorDot(vCur.normal, wi);
    Expr dotGeoNorm = VectorDot(vCur.geoNormal, wi);

    Expr connectValid = Scalar(1.0f);
    auto numStable = (its_light_dist > Scalar(SHADOW_EPSILON) && Abs(dotShNorm) > Scalar(EDGE_EPSILON));
    auto noLightLeak = ((dotShNorm > Scalar(0.0f)) && (dotGeoNorm > Scalar(0.0f)));
    Expr visible;
    {
        Ray shadowRays;
        shadowRays.mNumRays = vCur.numOfSamples;
        shadowRays.mOrg = vCur.position;
        shadowRays.mDir = wi;
        shadowRays.mMin = Ones(vCur.numOfSamples) * Scalar(SHADOW_EPSILON);
        shadowRays.mMax = its_light_dist - Scalar(SHADOW_EPSILON);
        scene.Occluded(shadowRays, visible);
    }
    connectValid = numStable * noLightLeak * visible;

    Expr wo = VectorNormalize(vCamera.position - vCur.position);
    Expr nDotWi = VectorDot(vLight.normal, -wi);
    Expr G = Abs(nDotWi) / its_light_dist_squared;
    Expr pdf = Detach(vLight.pdf);
    Expr Le = scene.mLights[scene.mAreaLightIndex]->Eval(vLight, -wi);

    // Sample point on area light
    for (int iBSDF = 0; iBSDF < scene.mBSDFCount; ++iBSDF) 
    {
        IndexMask mask_bsdf = vCur.bsdfId == Scalar(iBSDF);
        Expr condBsdf = (vCur.bsdfId == Scalar(iBSDF));
        if (mask_bsdf.sum == 0) continue;

        SpatialVertices curV = vCur.GetMaskedCopy(mask_bsdf);
        Expr bsdfVal = scene.mBsdfs[iBSDF]->Eval(curV, 
            Mask(wo, mask_bsdf, 0), Mask(wi, mask_bsdf, 0)) * Abs(Mask(dotShNorm, mask_bsdf, 0));

        Expr curG = Mask(G, mask_bsdf, 0);
        Expr curLe = Mask(Le, mask_bsdf, 0);
        Expr curConnectValid = Mask(connectValid, mask_bsdf, 0);
        Expr curJ = Mask(vLight.J, mask_bsdf, 0);

        li = li + IndexedWrite(bsdfVal * curG * curLe * curJ * curConnectValid, mask_bsdf.index, li->GetShape(), 0);
    }

    // r.W = Tensorf(r.W);
    li = Tensorf(li);

    res = NaNToZero(li * Where(pdf > Scalar(0.0f), Scalar(1.0f) / pdf, Scalar(0.0f)));
}

} // namespace TensorRay 
} // namespace EDX