#include "Falcor.h"
#include "PrtBase.h"
#include "RenderGraph/RenderPassStandardFlags.h"
#include "RenderPasses/DepthPass.h"

#include <memory>
#include <stdexcept>

namespace
{
    // File-local constants. The pointers are `const char* const` so they cannot be reseated by accident.

    // Shaders
    const char* const kComputePerVertexShaderFile = "computePerVertex.cs.slang";
    const char* const kComputePerVertexPerLightShaderFile = "computePerVertexPerLight.cs.slang";

    const char* const kPrecomputeShaderFile = "precompute.cs.slang";
    const char* const kZhToShShaderFile = "zhToSh.cs.slang";

    const char* const kHermiteShaderFile = "HermiteInterp.cs.slang";
    const char* const kTrilinearShaderFile = "trilinearInterp.cs.slang";
    const char* const kTaylorShaderFile = "TaylorInterp.cs.slang";

    const char* const kLightCenterShaderFile = "calcLightCenters.cs.slang";

    // Scripting dictionary keys (read by parseDictionary, written by getScriptingDictionary)
    const char* const kAlgo = "algo";
    const char* const kAnimation = "animation";

    const char* const kOutDir = "outDir";
    const char* const kCoeffFilename = "coeffFilename";
    const char* const kGridReso = "gridReso";
    const char* const kLightDisplayScale = "lightDisplayScale";
    const char* const kLightIntensityScale = "lightIntensityScale";

    // Output channels written by this pass as UAVs (in addition to the GBuffer render targets)
    const ChannelList kPrtOutputChannels =
    {
        { "color", "gColor", "direct illumination color" },
        { "debug", "gDebug", "values for debugging" },
    };
}

// Default constructor. mPrtParams is brace-initialized here; preparePRT() and
// preprocessScene() fill in its real values later.
PrtBase::PrtBase()
    : mPrtParams{}
{}

/** Reads this pass's options from a scripting dictionary.

    The whole dictionary is first handed to GBuffer::parseDictionary(). The loop below then picks out
    the PRT-specific keys. Every other key, including any option consumed by the GBuffer base, reaches
    the final branch and is reported with a warning. Always returns true.
*/
bool PrtBase::parseDictionary(const Dictionary& dict)
{
    GBuffer::parseDictionary(dict);
    for (const auto& v : dict)
    {
        if (v.key() == kAlgo)
        {
            mAlgo = v.val();
        }
        else if (v.key() == kAnimation)
        {
            mAnimation = v.val();
        }
        else if (v.key() == kOutDir)
        {
            // Copy-initialize a std::string first. Presumably, assigning v.val() directly to a std::string
            // member is ambiguous between std::string's operator= overloads.
            std::string tmpStr = v.val();
            mOutDir = tmpStr;
        }
        else if (v.key() == kCoeffFilename)
        {
            std::string tmpStr = v.val();
            mCoeffFilename = tmpStr;
        }
        else if (v.key() == kGridReso)
        {
            mGridReso = v.val();
        }
        else if (v.key() == kLightDisplayScale)
        {
            mLightDisplayScale = v.val();
        }
        else if (v.key() == kLightIntensityScale)
        {
            mLightIntensityScale = v.val();
        }
        else
        {
            // This class emits the warning, and the key may legitimately belong to the GBuffer base,
            // so say both rather than blaming "a GBuffer dictionary".
            logWarning("Unknown field `" + v.key() + "` in a PrtBase dictionary (it may be a GBuffer option)");
        }
    }
    return true;
}

/** Serializes this pass's options.
    Returns the GBuffer base dictionary extended with the PRT keys that parseDictionary() understands.
*/
Dictionary PrtBase::getScriptingDictionary()
{
    // Start from the base-class options so they round-trip as well.
    Dictionary options(GBuffer::getScriptingDictionary());

    // Algorithm / animation selection.
    options[kAlgo] = mAlgo;
    options[kAnimation] = mAnimation;

    // Data location and PRT tuning parameters.
    options[kOutDir] = mOutDir;
    options[kCoeffFilename] = mCoeffFilename;
    options[kGridReso] = mGridReso;
    options[kLightDisplayScale] = mLightDisplayScale;
    options[kLightIntensityScale] = mLightIntensityScale;

    return options;
}

/** Declares the pass outputs.
    - "depthStencil" (D32Float, depth-stencil): always present.
    - kGBufferChannels: render targets.
    - kPrtOutputChannels: UAVs.
    Channels whose descriptor is flagged optional are declared optional.
*/
RenderPassReflection PrtBase::reflect(const CompileData& compileData)
{
    RenderPassReflection r;

    // Add the required depth/stencil output. This always exists.
    r.addOutput("depthStencil", "depth and stencil").format(ResourceFormat::D32Float).bindFlags(Resource::BindFlags::DepthStencil);

    // Add all the other outputs.
    // The default GBuffer channels are written as render targets. The PRT channels are written as UAVs,
    // because there is no way to assign/pack additional render targets yet.
    auto addOutput = [&r](const ChannelDesc& output, Resource::BindFlags bindFlags)
    {
        auto& f = r.addOutput(output.name, output.desc).format(output.format).bindFlags(bindFlags);
        if (output.optional) f.flags(RenderPassReflection::Field::Flags::Optional);
    };
    // Iterate by const reference: ChannelDesc holds strings, so copying per element is wasted work.
    for (const auto& channel : kGBufferChannels) addOutput(channel, Resource::BindFlags::RenderTarget);
    for (const auto& channel : kPrtOutputChannels) addOutput(channel, Resource::BindFlags::UnorderedAccess);

    return r;
}

/** Compiles the GBuffer base, then builds a private one-pass render graph.
    The graph holds a DepthPass (D32Float) whose "DepthPrePass.depth" output execute() consumes.
*/
void PrtBase::compile(RenderContext* pContext, const CompileData& compileData)
{
    GBuffer::compile(pContext, compileData);

    // A new graph is built on every compile; reassigning the pointer drops our reference to the old one.
    mpDepthPrePassGraph = RenderGraph::create("Depth Pre-Pass");

    DepthPass::SharedPtr pDepthOnly = DepthPass::create(pContext);
    pDepthOnly->setDepthBufferFormat(ResourceFormat::D32Float);

    RenderGraph::SharedPtr& pGraph = mpDepthPrePassGraph;
    pGraph->addPass(pDepthOnly, "DepthPrePass");
    pGraph->markOutput("DepthPrePass.depth");
    pGraph->setScene(mpScene);
}

void PrtBase::preparePRT()
{
    int lmax = 9;
    mPrtParams.lmax = lmax;

    FILE* fp;

    if (mOutDir == "" || mCoeffFilename == "")
        logError("No outDir or coeffFilename specified!");

    // load transfer coefficients
    std::string coeffFile = mOutDir + mCoeffFilename;

    std::vector<float> Tcoeffs;
    Tcoeffs.reserve(500000 * lmax * lmax);

    int numVertex = 0;
    fp = fopen(coeffFile.c_str(), "r");
    float tmpVal;
    while (fscanf(fp, "%f ", &tmpVal) != EOF)
    {
        Tcoeffs.push_back(tmpVal);
        for (int i = 0; i < lmax * lmax - 1; i++)
        {
            fscanf(fp, "%f ", &tmpVal);
            Tcoeffs.push_back(tmpVal);
        }
        numVertex++;
    }
    fclose(fp);

    mPrtParams.numVertex = numVertex;
    mpTcoeffs = TypedBuffer<float>::create(numVertex * lmax * lmax);
    mpTcoeffs->setBlob((void*)(Tcoeffs.data()), 0, numVertex * lmax * lmax * sizeof(float));

    // precompute Legendre polynomials
    int legendreReso = 10000;
    mPrtParams.legendreReso = legendreReso;

    std::vector<float> legendre_2345((legendreReso + 1) * 4, 0);
    std::vector<float> legendre_6789((legendreReso + 1) * 4, 0);

    for (int i = 0; i <= legendreReso; i++)
    {
        float X[10];
        X[1] = (float)i / legendreReso;
        for (int j = 2; j < 10; j++)
            X[j] = X[j - 1] * X[1];

        legendre_2345[i * 4 + 0] = (0.5f * (3 * X[2] - 1));
        legendre_2345[i * 4 + 1] = (0.5f * (5 * X[3] - 3 * X[1]));
        legendre_2345[i * 4 + 2] = ((35 * X[4] - 30 * X[2] + 3) / 8.0f);
        legendre_2345[i * 4 + 3] = ((63 * X[5] - 70 * X[3] + 15 * X[1]) / 8.0f);

        legendre_6789[i * 4 + 0] = ((231 * X[6] - 315 * X[4] + 105 * X[2] - 5) / 16.0f);
        legendre_6789[i * 4 + 1] = ((429 * X[7] - 693 * X[5] + 315 * X[3] - 35 * X[1]) / 16.0f);
        legendre_6789[i * 4 + 2] = ((6435 * X[8] - 12012 * X[6] + 6930 * X[4] - 1260 * X[2] + 35) / 128.0f);
        legendre_6789[i * 4 + 3] = ((12155 * X[9] - 25740 * X[7] + 18018 * X[5] - 4620 * X[3] + 315 * X[1]) / 128.0f);
    }

    mpLegendre2345 = Texture::create1D(legendreReso + 1, ResourceFormat::RGBA32Float, 1, Resource::kMaxPossible,
        (void*)(legendre_2345.data()),
        Resource::BindFlags::ShaderResource | Resource::BindFlags::RenderTarget | Resource::BindFlags::UnorderedAccess);
    mpLegendre6789 = Texture::create1D(legendreReso + 1, ResourceFormat::RGBA32Float, 1, Resource::kMaxPossible,
        (void*)(legendre_6789.data()),
        Resource::BindFlags::ShaderResource | Resource::BindFlags::RenderTarget | Resource::BindFlags::UnorderedAccess);

    Sampler::Desc samplerDesc;
    // linear interp
    samplerDesc.setFilterMode(Sampler::Filter::Linear, Sampler::Filter::Linear, Sampler::Filter::Linear);
    samplerDesc.setAddressingMode(Sampler::AddressMode::Clamp, Sampler::AddressMode::Clamp, Sampler::AddressMode::Clamp);
    mpLegendreSampler = Sampler::create(samplerDesc);

    mPrtParams.emitterGroupSize = 8;
    mPrtParams.gridReso = uint3(mGridReso, mGridReso, mGridReso);
    mPrtParams.lightDisplayScale = mLightDisplayScale;
    mPrtParams.lightIntensityScale = mLightIntensityScale;

    // Create lighting coefficient and vertex color buffers
    mpVertexColors = TypedBuffer<float>::create(mPrtParams.numVertex * 3);

    uint3 gridReso = mPrtParams.gridReso;
    mpZhGrad = TypedBuffer<float>::create(gridReso.x * gridReso.y * gridReso.z * 3 * 4 * (2 * lmax - 1) * lmax);
    for (uint c = 0; c < 3; c++)
        mpShGrad[c] = TypedBuffer<float4>::create((lmax * lmax) * gridReso.x * gridReso.y * gridReso.z);
}

void PrtBase::prepareShaders()
{
    // Algo 0
    // create SH project per vertex program
    Program::Desc desc;
    desc.addShaderLibrary(mShaderDir + kComputePerVertexShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpComputePerVertexPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpComputePerVertexPass) throw std::exception("Failed to create program");

    // Algo 4
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kComputePerVertexPerLightShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpComputePerVertexPerLightPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpComputePerVertexPerLightPass) throw std::exception("Failed to create program");

    // Algo 1 & 2 & 3
    // create SH grad precomputation program
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kPrecomputeShaderFile).csEntry("main").setShaderModel(mSMVersion);

    mpPrecomputePass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpPrecomputePass) throw std::exception("Failed to create program");
    mpSampleGenerator->prepareProgram(mpPrecomputePass->getProgram().get());

    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kZhToShShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpZhToShPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpZhToShPass) throw std::exception("Failed to create program");

    // Algo 1
    // Hermite interpolation
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kHermiteShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpHermitePass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpHermitePass) throw std::exception("Failed to create program");

    // Algo 2
    // Trilinear interpolation
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kTrilinearShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpTrilinearPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpTrilinearPass) throw std::exception("Failed to create program");

    // Algo 3
    // Taylor-series based interpolation [Annen et al. 2004]
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kTaylorShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpTaylorPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpTaylorPass) throw std::exception("Failed to create program");

    // prep animation
    desc = Program::Desc();
    desc.addShaderLibrary(mShaderDir + kLightCenterShaderFile).csEntry("main").setShaderModel(mSMVersion);
    mpLightCenterPass = ComputePass::create(desc, Program::DefineList(), false);
    if (!mpLightCenterPass) throw std::exception("Failed to create program");
}

void PrtBase::preprocessScene()
{
    // reset frameCount
    mGBufferParams.frameCount = 0;

    // set up receiver/emitter mesh index
    mPrtParams.stReceiverIndex = 0;
    uint32_t offset = 0;

    uint32_t numMeshInstance = mpScene->getMeshInstanceCount();
    for (uint32_t i = 0; i < numMeshInstance; i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const MeshDesc& mesh = mpScene->getMesh(inst.meshID);
        const auto& mat = mpScene->getMaterial(inst.materialID);

        if (mat->isEmissive())
        {
            mPrtParams.stEmitterIndex = i;
            break;
        }

        offset += mesh.vertexCount;
    }

    mPrtParams.numReceivers = mPrtParams.stEmitterIndex - mPrtParams.stReceiverIndex;
    mPrtParams.numEmitters = numMeshInstance - mPrtParams.stEmitterIndex;

    // set up emitter triangles
    mPrtParams.numEmitterTriangles = 0;
    mEmitterToMeshData.clear();
    for (uint32_t i = mPrtParams.stEmitterIndex; i < mPrtParams.stEmitterIndex + mPrtParams.numEmitters / 2; i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const MeshDesc& mesh = mpScene->getMesh(inst.meshID);

        mPrtParams.numEmitterTriangles += mesh.indexCount / 3;
        for (uint j = 0; j < mesh.indexCount / 3; j++)
        {
            mEmitterToMeshData.push_back(uint2(i, j));
        }
    }

    // compute bounding box for receivers
    auto& pBlock = mpScene->getParameterBlock();

    TypedBufferBase::SharedPtr worldMatBuffer = pBlock->getTypedBuffer("worldMatrices");
    float4 *worldMatData = static_cast<float4*>(worldMatBuffer->getData());

    bool first = true;
    for (uint i = mPrtParams.stReceiverIndex; i < mPrtParams.stReceiverIndex + mPrtParams.numReceivers; i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const BoundingBox& meshBB = mpScene->getMeshBounds(inst.meshID);

        uint32_t matrixID = inst.globalMatrixID;
        glm::float4x4 worldMat = {
            worldMatData[matrixID * 4 + 0],
            worldMatData[matrixID * 4 + 1],
            worldMatData[matrixID * 4 + 2],
            worldMatData[matrixID * 4 + 3]
        };

        BoundingBox instBB = meshBB.transform(worldMat);

        if (first)
        {
            mBoundExLight = instBB;
            first = false;
        }
        else
        {
            mBoundExLight = BoundingBox::fromUnion(mBoundExLight, instBB);
        }
    }

    uint3 gridReso = mPrtParams.gridReso;
    float3 dBoundary(1e-3, 1e-3, 1e-3);

    mPrtParams.gridMin = mBoundExLight.getMinPos() - dBoundary;
    mPrtParams.gridStep = (mBoundExLight.getMaxPos() - mBoundExLight.getMinPos() + dBoundary * 2.f) /
        float3(gridReso - uint3(1, 1, 1));

    // create buffers
    if (mpEmitterToMesh) mpEmitterToMesh = nullptr;
    mpEmitterToMesh = TypedBuffer<uint2>::create((uint32_t)(mEmitterToMeshData.size()));
    mpEmitterToMesh->setBlob((void*)(mEmitterToMeshData.data()), 0, mEmitterToMeshData.size() * sizeof(uint2));

#if 0
    // debug
    FILE *fp = fopen((mOutDir + "debug_bound.txt").c_str(), "w");
    fprintf(fp, "%.6f %.6f %.6f\n", mBoundExLight.getMinPos().x, mBoundExLight.getMinPos().y, mBoundExLight.getMinPos().z);
    fprintf(fp, "%.6f %.6f %.6f\n", mBoundExLight.getMaxPos().x, mBoundExLight.getMaxPos().y, mBoundExLight.getMaxPos().z);
    fprintf(fp, "%.6f %.6f %.6f\n", mPrtParams.gridStep.x, mPrtParams.gridStep.y, mPrtParams.gridStep.z);
    fclose(fp);

    fp = fopen((mOutDir + "debug_scene.txt").c_str(), "w");
    fprintf(fp, "Prt params\n");
    fprintf(fp, "Receiver: %u %u\n", mPrtParams.stReceiverIndex, mPrtParams.numReceivers);
    fprintf(fp, "Emitter: %u %u\n", mPrtParams.stEmitterIndex, mPrtParams.numEmitters);
    fprintf(fp, "\n");

    fprintf(fp, "num vertices = %u\n", mPrtParams.numVertex);
    fprintf(fp, "num emitter triangles = %u\n", mPrtParams.numEmitterTriangles);
    //for (uint32_t i = 0; i < mPrtParams.numEmitterTriangles; i++)
    //    fprintf(fp, "%u %u\n", mEmitterToMeshData[i].x, mEmitterToMeshData[i].y);
    fprintf(fp, "\n");

    for (uint32_t i = 0; i < mpScene->getMeshInstanceCount(); i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const MeshDesc& mesh = mpScene->getMesh(inst.meshID);
        const auto& mat = mpScene->getMaterial(inst.materialID);
        fprintf(fp, "instance %u: mesh %u, mat %u, emissive %u\n", i, inst.meshID, inst.materialID,
            mat->isEmissive());
    }

    fprintf(fp, "\n");
    for (uint32_t i = 0; i < mpScene->getMeshCount(); i++)
    {
        const MeshDesc& mesh = mpScene->getMesh(i);
        fprintf(fp, "mesh %u: vbOffset = %u, ibOffset = %u\n",
            i, mesh.vbOffset, mesh.ibOffset);
        fprintf(fp, "vertex count = %u, index count = %u, material id = %u\n",
            mesh.vertexCount, mesh.indexCount, mesh.materialID);
    }

    fprintf(fp, "\n");
    for (uint32_t i = 0; i < mpScene->getMaterialCount(); i++)
    {
        const auto& mat = mpScene->getMaterial(i);
        fprintf(fp, "material %u\n", i);
        fprintf(fp, "diffuse color = (%.6f, %.6f, %.6f)\n", mat->getBaseColor().r, mat->getBaseColor().g,
            mat->getBaseColor().b);

        float roughness = 1.f - mat->getSpecularParams().a;
        float shininess = convertRoughnessToShininess(roughness);
        fprintf(fp, "specular color = (%.6f, %.6f, %.6f), %.6f\n", mat->getSpecularParams().r, mat->getSpecularParams().g,
            mat->getSpecularParams().b, shininess);

        fprintf(fp, "emissive color = (%.6f, %.6f, %.6f)\n", mat->getEmissiveColor().r, mat->getEmissiveColor().g,
            mat->getEmissiveColor().b);
        fprintf(fp, "double side = %u\n", mat->isDoubleSided());
    }
    fclose(fp);
#endif
}

/** Binds a new scene.
    Adds the scene defines to the raster program and to every scene-reading compute pass, recreates their
    vars, forwards the scene to the depth pre-pass, then recomputes the scene-derived PRT parameters.
    A null scene is forwarded to the base class and the depth pre-pass only: execute() already returns early
    when mpScene is null, and the remaining setup dereferences the scene.
*/
void PrtBase::setScene(RenderContext* pRenderContext, const Scene::SharedPtr& pScene)
{
    GBuffer::setScene(pRenderContext, pScene);

    if (!pScene)
    {
        if (mpDepthPrePassGraph) mpDepthPrePassGraph->setScene(pScene);
        return;
    }

    mRaster.pProgram->addDefines(pScene->getSceneDefines());
    mpSampleGenerator->prepareProgram(mRaster.pProgram.get());

    mRaster.pVars = GraphicsVars::create(mRaster.pProgram.get());
    if (!mRaster.pVars) throw std::exception("Failed to create program vars");

    // Scene defines must be added before vars are created from the (re-specialized) program.
    auto addSceneDefines = [&pScene](auto& pPass) { pPass->getProgram()->addDefines(pScene->getSceneDefines()); };
    auto recreateVars = [](auto& pPass) { pPass->setVars(ComputeVars::create(pPass->getProgram().get())); };

    addSceneDefines(mpComputePerVertexPass);
    recreateVars(mpComputePerVertexPass);

    addSceneDefines(mpComputePerVertexPerLightPass);
    recreateVars(mpComputePerVertexPerLightPass);

    addSceneDefines(mpPrecomputePass);
    mpSampleGenerator->prepareProgram(mpPrecomputePass->getProgram().get());
    recreateVars(mpPrecomputePass);

    // The ZH -> SH pass is not given the scene defines; only its vars are recreated.
    recreateVars(mpZhToShPass);

    addSceneDefines(mpHermitePass);
    recreateVars(mpHermitePass);

    addSceneDefines(mpTrilinearPass);
    recreateVars(mpTrilinearPass);

    addSceneDefines(mpTaylorPass);
    recreateVars(mpTaylorPass);

    addSceneDefines(mpLightCenterPass);
    recreateVars(mpLightCenterPass);

    if (mpDepthPrePassGraph) mpDepthPrePassGraph->setScene(pScene);

    // set up related parameters
    preprocessScene();
    prepAnimation();
}

/** Updates the cull mode through the GBuffer base, then rebuilds the raster pipeline's rasterizer state
    from the resulting mCullMode.
*/
void PrtBase::setCullMode(RasterizerState::CullMode mode)
{
    GBuffer::setCullMode(mode);

    RasterizerState::Desc rasterDesc;
    rasterDesc.setCullMode(mCullMode);
    const RasterizerState::SharedPtr pRasterState = RasterizerState::create(rasterDesc);
    mRaster.pState->setRasterizerState(pRasterState);
}

/** Algo 0: evaluates per-vertex colors into mpVertexColors.
    Emitter triangles are processed in groups of emitterGroupSize, one dispatch per group.
    gLightStartIndex tells each dispatch where its group starts. renderData is unused.
*/
void PrtBase::computePerVertex(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpComputePerVertexPass;
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    // Constants shared by every dispatch.
    auto perFrameCB = pass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Legendre lookup tables and their sampler.
    pass["gLegendre2345"] = mpLegendre2345;
    pass["gLegendre6789"] = mpLegendre6789;
    pass["gLegendreSampler"] = mpLegendreSampler;

    // Inputs, then the output buffer.
    pass["gEmitterToMesh"] = mpEmitterToMesh;
    pass["gTcoeffs"] = mpTcoeffs;
    pass["gVertexColors"] = mpVertexColors;

    // Dispatch size: numVertex x 3 (the 3 presumably matches the per-vertex RGB layout of mpVertexColors).
    for (uint lightStart = 0; lightStart < mPrtParams.numEmitterTriangles; lightStart += mPrtParams.emitterGroupSize)
    {
        perFrameCB["gLightStartIndex"] = lightStart;
        pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
    }
}

/** Algo 4: one dispatch of the per-vertex, per-light pass into mpVertexColors.
    execute() calls this only while frameCount < numEmitterTriangles. renderData is unused.
*/
void PrtBase::computePerVertexPerLight(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpComputePerVertexPerLightPass;
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    // Constants.
    auto perFrameCB = pass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Legendre lookup tables and their sampler.
    pass["gLegendre2345"] = mpLegendre2345;
    pass["gLegendre6789"] = mpLegendre6789;
    pass["gLegendreSampler"] = mpLegendreSampler;

    // Inputs, then the output buffer.
    pass["gEmitterToMesh"] = mpEmitterToMesh;
    pass["gTcoeffs"] = mpTcoeffs;
    pass["gVertexColors"] = mpVertexColors;

    pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
}

/** Algos 1-3, stage 1: fills the SH-gradient grid.
    - First, the precompute pass writes ZH data at every grid point into mpZhGrad, one dispatch per group
      of emitterGroupSize emitter triangles.
    - Then the ZH -> SH pass converts it into mpShGrad[0..2].
    Throws std::exception if the sample generator cannot be bound. renderData is unused.
*/
void PrtBase::precomputeSH(RenderContext* pRenderContext, const RenderData& renderData)
{
    const uint3& reso = mPrtParams.gridReso;
    const uint numGridPoints = reso.x * reso.y * reso.z;

    // --- precompute ZH coefficients at grid points ---
    auto& precomputePass = mpPrecomputePass;
    const ComputeVars::SharedPtr& pPrecomputeVars = precomputePass->getVars();
    pPrecomputeVars->setParameterBlock("gScene", mpScene->getParameterBlock());

    if (!mpSampleGenerator->setIntoProgramVars(pPrecomputeVars.get()))
        throw std::exception("Failed to bind sample generator");

    auto precomputeCB = precomputePass["PerFrameCB"];
    precomputeCB["gParams"].setBlob(mGBufferParams);
    precomputeCB["gPrtParams"].setBlob(mPrtParams);

    precomputePass["gLegendre2345"] = mpLegendre2345;
    precomputePass["gLegendre6789"] = mpLegendre6789;
    precomputePass["gLegendreSampler"] = mpLegendreSampler;
    precomputePass["gEmitterToMesh"] = mpEmitterToMesh;
    precomputePass["gZhGrad"] = mpZhGrad;

    for (uint lightStart = 0; lightStart < mPrtParams.numEmitterTriangles; lightStart += mPrtParams.emitterGroupSize)
    {
        precomputeCB["gLightStartIndex"] = lightStart;
        precomputePass->execute(pRenderContext, numGridPoints, 2 * mPrtParams.lmax - 1, 1);
    }

    // --- ZH to SH ---
    auto& zhToShPass = mpZhToShPass;
    auto zhToShCB = zhToShPass["PerFrameCB"];
    zhToShCB["gParams"].setBlob(mGBufferParams);
    zhToShCB["gPrtParams"].setBlob(mPrtParams);

    zhToShPass["gZhGrad"] = mpZhGrad;
    for (uint channel = 0; channel < 3; ++channel)
        zhToShPass["gShGrad[" + std::to_string(channel) + "]"] = mpShGrad[channel];

    zhToShPass->execute(pRenderContext, numGridPoints, 3, 1);
}

/** Algo 1, stage 2: Hermite interpolation of the SH-gradient grid (mpShGrad) at each vertex.
    Results are written into mpVertexColors. renderData is unused.
*/
void PrtBase::interpolateHermite(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpHermitePass;
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    auto perFrameCB = pass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Inputs: per-vertex transfer coefficients and the three SH-gradient grids.
    pass["gTcoeffs"] = mpTcoeffs;
    for (uint channel = 0; channel < 3; ++channel)
        pass["gShGrad[" + std::to_string(channel) + "]"] = mpShGrad[channel];

    // Output.
    pass["gVertexColors"] = mpVertexColors;

    pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
}

/** Algo 2, stage 2: trilinear interpolation of the SH-gradient grid (mpShGrad) at each vertex.
    Results are written into mpVertexColors. renderData is unused.
*/
void PrtBase::interpolateTrilinear(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpTrilinearPass;
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    auto perFrameCB = pass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Inputs: per-vertex transfer coefficients and the three SH-gradient grids.
    pass["gTcoeffs"] = mpTcoeffs;
    for (uint channel = 0; channel < 3; ++channel)
        pass["gShGrad[" + std::to_string(channel) + "]"] = mpShGrad[channel];

    // Output.
    pass["gVertexColors"] = mpVertexColors;

    pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
}

/** Algo 3, stage 2: Taylor-series interpolation of the SH-gradient grid (mpShGrad) at each vertex.
    Results are written into mpVertexColors. renderData is unused.
*/
void PrtBase::interpolateTaylor(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpTaylorPass;
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    auto perFrameCB = pass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Inputs: per-vertex transfer coefficients and the three SH-gradient grids.
    pass["gTcoeffs"] = mpTcoeffs;
    for (uint channel = 0; channel < 3; ++channel)
        pass["gShGrad[" + std::to_string(channel) + "]"] = mpShGrad[channel];

    // Output.
    pass["gVertexColors"] = mpVertexColors;

    pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
}

/** Renders one frame.
    1. Runs the depth pre-pass and copies its depth into "depthStencil".
    2. Runs the selected algorithm (mAlgo) to fill mpVertexColors.
    3. Rasterizes the scene into the GBuffer render targets and the PRT UAV outputs.
    Unconnected output channels are tolerated: they are bound as nullptr instead of being dereferenced.
    Throws std::exception if the sample generator cannot be bound.
*/
void PrtBase::execute(RenderContext* pRenderContext, const RenderData& renderData)
{
    // Update refresh flag if options that affect the output have changed.
    if (mOptionsChanged)
    {
        Dictionary& dict = renderData.getDictionary();
        auto prevFlags = (Falcor::RenderPassRefreshFlags)(dict.keyExists(kRenderPassRefreshFlags) ? dict[Falcor::kRenderPassRefreshFlags] : 0u);
        dict[Falcor::kRenderPassRefreshFlags] = (uint32_t)(prevFlags | Falcor::RenderPassRefreshFlags::RenderOptionsChanged);
        mOptionsChanged = false;

        // Restart frame counting (the Reference algorithm keys its dispatches off frameCount).
        mGBufferParams.frameCount = 0;
    }

    if (mpScene == nullptr)
    {
        logWarning("PrtViewer::execute() - No scene available");
        return;
    }

    // Returns the texture bound to a render-graph channel, or nullptr when the channel is not connected.
    auto getChannelTexture = [&renderData](const std::string& name) -> Texture::SharedPtr
    {
        auto pResource = renderData[name];
        if (!pResource) return nullptr;
        return pResource->asTexture();
    };

    // Compute light centers for animation
    // Update once the scene is reloaded
    if (!mLightCentersReady && mAnimation > 0)
    {
        calcLightCenters(pRenderContext, renderData);
        mLightCentersReady = true;
    }

    // Perform animation.
    if (mAnimation == 1)
        animateDragonAndBunny();

    mpDepthPrePassGraph->execute(pRenderContext);
    mpFbo->attachDepthStencilTarget(mpDepthPrePassGraph->getOutput("DepthPrePass.depth")->asTexture());
    pRenderContext->copyResource(renderData["depthStencil"].get(), mpDepthPrePassGraph->getOutput("DepthPrePass.depth").get());

    const uint32_t numGBufferChannels = static_cast<uint32_t>(kGBufferChannels.size());
    for (uint32_t i = 0; i < numGBufferChannels; ++i)
    {
        mpFbo->attachColorTarget(getChannelTexture(kGBufferChannels[i].name), i);
    }

    pRenderContext->clearFbo(mpFbo.get(), vec4(0), 1.f, 0, FboAttachmentType::Color);
    mRaster.pState->setFbo(mpFbo);

    bool success = mpSampleGenerator->setIntoProgramVars(mRaster.pVars.get());
    if (!success) throw std::exception("Failed to bind sample generator");

    if (mAlgo == 0)
    {
        Profiler::startEvent("Analytic 2018");
        computePerVertex(pRenderContext, renderData);
        Profiler::endEvent("Analytic 2018");
    }
    else if (mAlgo == 1)
    {
        Profiler::startEvent("Precompute grid");
        precomputeSH(pRenderContext, renderData);
        Profiler::endEvent("Precompute grid");

        Profiler::startEvent("Hermite interp");
        interpolateHermite(pRenderContext, renderData);
        Profiler::endEvent("Hermite interp");
    }
    else if (mAlgo == 2)
    {
        Profiler::startEvent("Precompute grid");
        precomputeSH(pRenderContext, renderData);
        Profiler::endEvent("Precompute grid");

        Profiler::startEvent("Trilinear interp");
        interpolateTrilinear(pRenderContext, renderData);
        Profiler::endEvent("Trilinear interp");
    }
    else if (mAlgo == 3)
    {
        Profiler::startEvent("Precompute grid");
        precomputeSH(pRenderContext, renderData);
        Profiler::endEvent("Precompute grid");

        Profiler::startEvent("Taylor-series interp");
        interpolateTaylor(pRenderContext, renderData);
        Profiler::endEvent("Taylor-series interp");
    }
    else if (mAlgo == 4)
    {
        Profiler::startEvent("Reference");
        // One dispatch per frame, until frameCount reaches numEmitterTriangles.
        if (mGBufferParams.frameCount < mPrtParams.numEmitterTriangles)
            computePerVertexPerLight(pRenderContext, renderData);
        Profiler::endEvent("Reference");
    }

    // binding variables
    mRaster.pVars["PerFrameCB"]["gParams"].setBlob(mGBufferParams);
    mRaster.pVars["PerFrameCB"]["gPrtParams"].setBlob(mPrtParams);

    mRaster.pVars["gVertexColors"] = mpVertexColors;

    // UAV output variables: clear each connected output to zero, then bind it (nullptr if unconnected).
    for (const auto& channel : kPrtOutputChannels)
    {
        Texture::SharedPtr pTex = getChannelTexture(channel.name);
        // Debug alternative clear color for "color": glm::vec4(0.5, 0.8, 0.25, 1.0)
        if (pTex) pRenderContext->clearUAV(pTex->getUAV().get(), glm::vec4(0, 0, 0, 0));
        mRaster.pVars[channel.texname] = pTex;
    }

    Scene::RenderFlags flags = mForceCullMode ? Scene::RenderFlags::UserRasterizerState : Scene::RenderFlags::None;
    mpScene->render(pRenderContext, mRaster.pState.get(), mRaster.pVars.get(), flags);

    if (!mPauseAnimation)
        mGBufferParams.frameCount++;
}

/** Draws the algorithm and animation dropdowns.
    If either selection changed, sets mOptionsChanged so execute() raises the refresh flag.
*/
void PrtBase::renderUI(Gui::Widgets& widget)
{
    // Values match the mAlgo branches dispatched in execute().
    static const Gui::DropdownList kAlgorithmChoices =
    {
        {0, "Analytic 2018"},
        {1, "Ours Hermite"},
        {2, "Trilinear"},
        {3, "Taylor"},
        {4, "Reference"}
    };
    static const Gui::DropdownList kAnimationChoices =
    {
        {0, "None"},
        {1, "Dragon and bunny"}
    };

    // Both widgets are drawn every frame (no short-circuit between them).
    const bool algoChanged = widget.dropdown("Algorithm", kAlgorithmChoices, mAlgo);
    const bool animationChanged = widget.dropdown("Animation", kAnimationChoices, mAnimation);

    if (algoChanged || animationChanged)
    {
        mOptionsChanged = true;
    }
}

bool PrtBase::onKeyEvent(const KeyboardEvent& keyEvent)
{
    const CameraData& cameraData = mpScene->getCamera()->getData();
    auto& pBlock = mpScene->getParameterBlock();
    bool changed = false;

    TypedBufferBase::SharedPtr worldMatBuffer = pBlock->getTypedBuffer("worldMatrices");
    TypedBufferBase::SharedPtr invTWorldMatBuffer = pBlock->getTypedBuffer("inverseTransposeWorldMatrices");

    float4 *worldMatData = static_cast<float4*>(worldMatBuffer->getData());
    float4 *invTWorldMatData = static_cast<float4*>(invTWorldMatBuffer->getData());

    std::vector<glm::float4x4> worldMats(mPrtParams.numEmitters);
    std::vector<glm::float4x4> invTWorldMats(mPrtParams.numEmitters);

    for (uint i = mPrtParams.stEmitterIndex; i < mPrtParams.stEmitterIndex + mPrtParams.numEmitters; i++)
    {
        uint32_t matrixID = mpScene->getMeshInstance(i).globalMatrixID;
        worldMats[i - mPrtParams.stEmitterIndex] = {
           worldMatData[matrixID * 4 + 0],
           worldMatData[matrixID * 4 + 1],
           worldMatData[matrixID * 4 + 2],
           worldMatData[matrixID * 4 + 3]
        };
    }

    std::vector<glm::float4x4> transfMats(mPrtParams.numEmitters);

    if (keyEvent.key == KeyboardEvent::Key::Left && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::translate(mat4(1), glm::float3(-0.05f, 0, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Right && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::translate(mat4(1), glm::float3(0.05f, 0, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Up && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::translate(mat4(1), glm::float3(0, 0.02f, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Down && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::translate(mat4(1), glm::float3(0, -0.02f, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::LeftBracket && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(-1.f), glm::float3(1.f, 0, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::RightBracket && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(1.f), glm::float3(1.f, 0, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Semicolon && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(-1.f), glm::float3(0, 1.f, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Apostrophe && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for(uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(1.f), glm::float3(0, 1.f, 0));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Period && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(-1.f), glm::float3(0, 0, 1.f));
        }
        changed = true;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Slash && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transfMats[i] = glm::rotate(mat4(1), glm::radians(1.f), glm::float3(0, 0, 1.f));
        }
        changed = true;
    }
    // Control animation
    else if (keyEvent.key == KeyboardEvent::Key::Key0 && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        mPauseAnimation = (!mPauseAnimation);
    }
    else if (keyEvent.key == KeyboardEvent::Key::Minus && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        mGBufferParams.frameCount -= 1;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Equal && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        mGBufferParams.frameCount += 1;
    }
    else if (keyEvent.key == KeyboardEvent::Key::Key9 && keyEvent.type == KeyboardEvent::Type::KeyPressed)
    {
        mPrintDebug = (!mPrintDebug);
    }

    if (changed)
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            worldMats[i] = transfMats[i] * worldMats[i];
            invTWorldMats[i] = glm::transpose(glm::inverse(worldMats[i]));
        }

        for (uint i = mPrtParams.stEmitterIndex; i < mPrtParams.stEmitterIndex + mPrtParams.numEmitters; i++)
        {
            uint32_t matrixID = mpScene->getMeshInstance(i).globalMatrixID;
            for (int j = 0; j < 4; j++)
            {
                worldMatData[matrixID * 4 + j] = worldMats[i - mPrtParams.stEmitterIndex][j];
                invTWorldMatData[matrixID * 4 + j] = invTWorldMats[i - mPrtParams.stEmitterIndex][j];
            }
        }
        worldMatBuffer->setBlob(worldMatData, 0, worldMatBuffer->getSize());
        invTWorldMatBuffer->setBlob(invTWorldMatData, 0, invTWorldMatBuffer->getSize());

        mGBufferParams.frameCount = 0;
    }

    return changed;
}

void PrtBase::prepAnimation()
{
    mLightCentersReady = false;
    if (mpLightCenters)
        mpLightCenters = nullptr;

    mGBufferParams.frameCount = 0;

    // assume each light mesh has only 1 triangle!
    mpLightCenters = TypedBuffer<float>::create(mPrtParams.numEmitters * 3);

    // save objectToWorld transformation matrices for all the light triangles
    auto& pBlock = mpScene->getParameterBlock();
    TypedBufferBase::SharedPtr worldMatBuffer = pBlock->getTypedBuffer("worldMatrices");
    mObjectToWorld.resize(mPrtParams.numEmitters);

    float4* worldMatData = static_cast<float4*>(worldMatBuffer->getData());
    for (uint i = mPrtParams.stEmitterIndex; i < mPrtParams.stEmitterIndex + mPrtParams.numEmitters; i++)
    {
        uint32_t matrixID = mpScene->getMeshInstance(i).globalMatrixID;
        mObjectToWorld[i - mPrtParams.stEmitterIndex] = {
           worldMatData[matrixID * 4 + 0],
           worldMatData[matrixID * 4 + 1],
           worldMatData[matrixID * 4 + 2],
           worldMatData[matrixID * 4 + 3]
        };
    }
}

void PrtBase::calcLightCenters(RenderContext* pRenderContext, const RenderData& renderData)
{
    // Dispatch the light-center compute pass. It writes into mpLightCenters,
    // the buffer allocated by prepAnimation() (three floats per emitter).
    // renderData is currently unused.

    // Expose the scene's parameter block to the shader as gScene.
    const ComputeVars::SharedPtr& pVars = mpLightCenterPass->getVars();
    pVars->setParameterBlock("gScene", mpScene->getParameterBlock());

    // Per-frame constants: both parameter structs are uploaded verbatim.
    auto perFrameCB = mpLightCenterPass["PerFrameCB"];
    perFrameCB["gParams"].setBlob(mGBufferParams);
    perFrameCB["gPrtParams"].setBlob(mPrtParams);

    // Output buffer.
    mpLightCenterPass["gLightCenters"] = mpLightCenters;

    // Dispatch size: numEmitters x 1 x 1.
    const uint32_t dispatchX = mPrtParams.numEmitters;
    mpLightCenterPass->execute(pRenderContext, dispatchX, 1, 1);

#if 0
    // debug: dump each computed center followed by its stored object-to-world matrix
    float3* lightCenters = static_cast<float3*>(mpLightCenters->getData());
    FILE* fp = fopen((mOutDir + "debug_animation.txt").c_str(), "w");
    for (uint i = 0; i < mPrtParams.numEmitters; i++)
    {
        fprintf(fp, "Light %u: %.6f %.6f %.6f\n", i, lightCenters[i].x, lightCenters[i].y, lightCenters[i].z);
        for (uint j = 0; j < 4; j++)
        {
            for (uint k = 0; k < 4; k++)
                fprintf(fp, "%.6f ", mObjectToWorld[i][k][j]);
            fprintf(fp, "\n");
        }
        fprintf(fp, "\n");
    }
    fclose(fp);
#endif
}

void PrtBase::setObjToWorldMatrices(std::vector<float4x4> *before, std::vector<float4x4> *after)
{
    auto& pBlock = mpScene->getParameterBlock();

    TypedBufferBase::SharedPtr worldMatBuffer = pBlock->getTypedBuffer("worldMatrices");
    TypedBufferBase::SharedPtr invTWorldMatBuffer = pBlock->getTypedBuffer("inverseTransposeWorldMatrices");

    float4* worldMatData = static_cast<float4*>(worldMatBuffer->getData());
    float4* invTWorldMatData = static_cast<float4*>(invTWorldMatBuffer->getData());

    std::vector<glm::float4x4> worldMats(mPrtParams.numEmitters);
    std::vector<glm::float4x4> invTWorldMats(mPrtParams.numEmitters);

    for (uint i = 0; i < mPrtParams.numEmitters; i++)
    {
        if (before != NULL)
            worldMats[i] = mObjectToWorld[i] * (*before)[i];
        else
            worldMats[i] = mObjectToWorld[i];

        if (after != NULL)
            worldMats[i] = (*after)[i] * worldMats[i];
        invTWorldMats[i] = glm::transpose(glm::inverse(worldMats[i]));
    }

    for (uint i = mPrtParams.stEmitterIndex; i < mPrtParams.stEmitterIndex + mPrtParams.numEmitters; i++)
    {
        uint32_t matrixID = mpScene->getMeshInstance(i).globalMatrixID;
        for (int j = 0; j < 4; j++)
        {
            worldMatData[matrixID * 4 + j] = worldMats[i - mPrtParams.stEmitterIndex][j];
            invTWorldMatData[matrixID * 4 + j] = invTWorldMats[i - mPrtParams.stEmitterIndex][j];
        }
    }

    worldMatBuffer->setBlob(worldMatData, 0, worldMatBuffer->getSize());
    invTWorldMatBuffer->setBlob(invTWorldMatData, 0, invTWorldMatBuffer->getSize());
}

void PrtBase::animateDragonAndBunny()
{
    const uint32_t duration = 900;
    const float radiationScale = .8f;
    const float rollScale = 0.17f;
    const float rollScale2 = 0.2f;

    float t = (mGBufferParams.frameCount % duration) / (float)duration;
    float3* lightCenters = static_cast<float3*>(mpLightCenters->getData());

    std::vector<mat4> transforms(mPrtParams.numEmitters);
    const float timeline[7] = { 0.f, 0.05f, 0.2f, 0.45f, 0.7f, 0.95f, 1.f };

    if (t < timeline[1])
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transforms[i] = mat4(1);
        }
    }
    else if (t < timeline[2])
    {
        // moving out
        t = (t - timeline[1]) / (timeline[2] - timeline[1]);
        mat4 rotation = glm::rotate(mat4(1), 1.5f * t * pi<float>(), glm::vec3(0, 1, 0));
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            vec3 center = lightCenters[i];
            center.y = 0;

            mat4 radiation = translate(mat4(1), center * radiationScale * t);
            transforms[i] = rotation * radiation;
        }
    }
    else if (t < timeline[3])
    {
        // rolling 1
        t = (t - timeline[2]) / (timeline[3] - timeline[2]);
        mat4 rotation = glm::rotate(mat4(1), 1.5f * pi<float>(), glm::vec3(0, 1, 0));
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            vec3 center = lightCenters[i];
            mat4 roll = rotate(mat4(1), rollScale2 * glm::sin(t * 2.f * pi<float>()), glm::vec3(1, 0, 0));
            center.y = 0;
            mat4 radiation = translate(mat4(1), center * radiationScale);

            transforms[i] = rotation * radiation * roll;
        }
    }
    else if (t < timeline[4])
    {
        // rolling 2
        t = (t - timeline[3]) / (timeline[4] - timeline[3]);
        mat4 rotation = glm::rotate(mat4(1), 1.5f * pi<float>(), glm::vec3(0, 1, 0));
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            vec3 center = lightCenters[i];
            mat4 roll = rotate(mat4(1), rollScale2 * glm::sin(t * 2.f * pi<float>()), glm::vec3(0, 0, 1));

            center.y = 0;
            mat4 radiation = translate(mat4(1), center * radiationScale);

            transforms[i] = rotation * radiation * roll;
        }
    }
    else if (t < timeline[5])
    {
        // moving back
        t = (t - timeline[4]) / (timeline[5] - timeline[4]);
        mat4 rotation = glm::rotate(mat4(1), (1.5f + 2.5f * t) * pi<float>(), glm::vec3(0, 1, 0));
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            vec3 center = lightCenters[i];
            center.y = 0;

            mat4 radiation = translate(mat4(1), center * radiationScale * (1 - t));
            transforms[i] = rotation * radiation;
        }
    }
    else
    {
        for (uint i = 0; i < mPrtParams.numEmitters; i++)
        {
            transforms[i] = mat4(1);
        }
    }

    setObjToWorldMatrices(&transforms, NULL);
}
