/***************************************************************************
 # Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
 # and any modifications thereto.  Any use, reproduction, disclosure or
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 **************************************************************************/
#include "Falcor.h"
#include "PrtGlossy.h"
#include "RenderGraph/RenderPassStandardFlags.h"
#include "RenderPasses/DepthPass.h"

#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

namespace
{
    // Shader file names, relative to the shader directory.
    // init() appends kShShaderFile to mShaderDir to build the raster program ("vs"/"ps" entry points).
    const char* kShShaderFile = "PrtViewer.slang";

    // Not referenced in the visible part of this file.
    // NOTE(review): presumably the compute shader used for spherical-harmonics unit tests - confirm.
    const char* kUnitTestShaderFile = "ShUnitTest.cs.slang";
}

/** Factory for the pass.
    \param[in] pRenderContext Render context; not used by this function.
    \param[in] dict Dictionary forwarded unchanged to init().
    \return The newly created pass, or nullptr when init() returns false.
*/
PrtGlossy::SharedPtr PrtGlossy::create(RenderContext* pRenderContext, const Dictionary& dict)
{
    SharedPtr pNewPass = SharedPtr(new PrtGlossy);

    // Hand back an empty pointer rather than a half-initialized pass.
    if (!pNewPass->init(dict))
    {
        return nullptr;
    }
    return pNewPass;
}

bool PrtGlossy::init(const Dictionary& dict)
{
    mShaderDir = std::string(PROJECT_DIR) + R"(glossyShaders\)";
    mSMVersion = "6_1";

    // Create raster program
    Program::Desc desc;
    desc.addShaderLibrary(mShaderDir + kShShaderFile).vsEntry("vs").psEntry("ps").setShaderModel("5_1");
    mRaster.pProgram = GraphicsProgram::create(desc, Program::DefineList());
    if (!mRaster.pProgram) throw std::exception("Failed to create program");

    // Initialize graphics state
    mRaster.pState = GraphicsState::create();
    mRaster.pState->setProgram(mRaster.pProgram);

    // Set default cull mode
    setCullMode(mCullMode);

    // Set depth function
    DepthStencilState::Desc dsDesc;
    dsDesc.setDepthFunc(DepthStencilState::Func::Equal).setDepthWriteMask(false);
    DepthStencilState::SharedPtr pDsState = DepthStencilState::create(dsDesc);
    mRaster.pState->setDepthStencilState(pDsState);

    mpFbo = Fbo::create();

    // Parse dictionary parameters
    bool flag = parseDictionary(dict);
    if (!flag) throw std::exception("Failed to parse dictionary");

    // Create random engine
    mpSampleGenerator = SampleGenerator::create(SAMPLE_GENERATOR_TINY_UNIFORM);
    if (!mpSampleGenerator) throw std::exception("Failed to create sample generator");
    mpSampleGenerator->prepareProgram(mRaster.pProgram.get());

    // Prepare PRT: loading coefficients
    preparePRT();

    prepareShaders();

    return true;
}

void PrtGlossy::preparePRT()
{
    int lmax = 5;
    mPrtParams.lmax = lmax;

    FILE* fp;

    if (mOutDir == "" || mCoeffFilename == "")
        logError("No outDir or coeffFilename specified!");

    // load transfer coefficients
    std::string coeffFile = mOutDir + mCoeffFilename;

    std::vector<float> Tcoeffs;
    Tcoeffs.reserve(500000 * lmax * lmax);

    int numVertex = 0;
    fp = fopen(coeffFile.c_str(), "r");
    float tmpVal;
    while (fscanf(fp, "%f ", &tmpVal) != EOF)
    {
        Tcoeffs.push_back(tmpVal);
        for (int i = 1; i < lmax * lmax; i++)
        {
            fscanf(fp, "%f ", &tmpVal);
            Tcoeffs.push_back(tmpVal);
        }
        for (int i = lmax * lmax; i < 9 * 9; i++)
            fscanf(fp, "%f ", &tmpVal);
        numVertex++;
    }
    fclose(fp);

    mPrtParams.numVertex = numVertex;
    mpTcoeffs = TypedBuffer<float>::create(numVertex * lmax * lmax);
    mpTcoeffs->setBlob((void*)(Tcoeffs.data()), 0, numVertex * lmax * lmax * sizeof(float));

    // precompute Legendre polynomials
    int legendreReso = 10000;
    mPrtParams.legendreReso = legendreReso;

    std::vector<float> legendre_2345((legendreReso + 1) * 4, 0);
    std::vector<float> legendre_6789((legendreReso + 1) * 4, 0);

    for (int i = 0; i <= legendreReso; i++)
    {
        float X[10];
        X[1] = (float)i / legendreReso;
        for (int j = 2; j < 10; j++)
            X[j] = X[j - 1] * X[1];

        legendre_2345[i * 4 + 0] = (0.5f * (3 * X[2] - 1));
        legendre_2345[i * 4 + 1] = (0.5f * (5 * X[3] - 3 * X[1]));
        legendre_2345[i * 4 + 2] = ((35 * X[4] - 30 * X[2] + 3) / 8.0f);
        legendre_2345[i * 4 + 3] = ((63 * X[5] - 70 * X[3] + 15 * X[1]) / 8.0f);

        legendre_6789[i * 4 + 0] = ((231 * X[6] - 315 * X[4] + 105 * X[2] - 5) / 16.0f);
        legendre_6789[i * 4 + 1] = ((429 * X[7] - 693 * X[5] + 315 * X[3] - 35 * X[1]) / 16.0f);
        legendre_6789[i * 4 + 2] = ((6435 * X[8] - 12012 * X[6] + 6930 * X[4] - 1260 * X[2] + 35) / 128.0f);
        legendre_6789[i * 4 + 3] = ((12155 * X[9] - 25740 * X[7] + 18018 * X[5] - 4620 * X[3] + 315 * X[1]) / 128.0f);
    }

    mpLegendre2345 = Texture::create1D(legendreReso + 1, ResourceFormat::RGBA32Float, 1, Resource::kMaxPossible,
        (void*)(legendre_2345.data()),
        Resource::BindFlags::ShaderResource | Resource::BindFlags::RenderTarget | Resource::BindFlags::UnorderedAccess);
    mpLegendre6789 = Texture::create1D(legendreReso + 1, ResourceFormat::RGBA32Float, 1, Resource::kMaxPossible,
        (void*)(legendre_6789.data()),
        Resource::BindFlags::ShaderResource | Resource::BindFlags::RenderTarget | Resource::BindFlags::UnorderedAccess);

    Sampler::Desc samplerDesc;
    // linear interp
    samplerDesc.setFilterMode(Sampler::Filter::Linear, Sampler::Filter::Linear, Sampler::Filter::Linear);
    samplerDesc.setAddressingMode(Sampler::AddressMode::Clamp, Sampler::AddressMode::Clamp, Sampler::AddressMode::Clamp);
    mpLegendreSampler = Sampler::create(samplerDesc);

    mPrtParams.emitterGroupSize = 16;
    mPrtParams.gridReso = uint3(mGridReso, mGridReso, mGridReso);
    mPrtParams.lightDisplayScale = mLightDisplayScale;
    mPrtParams.lightIntensityScale = mLightIntensityScale;

    // Create lighting coefficient and vertex color buffers
    mpLcoeffs = TypedBuffer<float>::create(mPrtParams.numVertex * lmax * lmax * 3);
    mpVertexColors = TypedBuffer<float>::create(mPrtParams.numVertex * 3);

    uint3 gridReso = mPrtParams.gridReso;
    mpZhGrad = TypedBuffer<float>::create(gridReso.x * gridReso.y * gridReso.z * 3 * 4 * (2 * lmax - 1) * lmax);
    for (uint c = 0; c < 3; c++)
        mpShGrad[c] = TypedBuffer<float4>::create((lmax * lmax) * gridReso.x * gridReso.y * gridReso.z);
}

/** Binds the scene, PRT parameters and buffers to the per-vertex compute pass and runs it
    once for every group of emitter triangles.
    \param[in] pRenderContext Render context the pass is executed on.
    \param[in] renderData Not used by this function.
*/
void PrtGlossy::computePerVertex(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpComputePerVertexPass;

    // Scene data
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    // Per-frame constants
    pass["PerFrameCB"]["gParams"].setBlob(mGBufferParams);
    pass["PerFrameCB"]["gPrtParams"].setBlob(mPrtParams);

    // Legendre lookup textures and their sampler
    pass["gLegendre2345"] = mpLegendre2345;
    pass["gLegendre6789"] = mpLegendre6789;
    pass["gLegendreSampler"] = mpLegendreSampler;

    // Input buffers
    pass["gEmitterToMesh"] = mpEmitterToMesh;
    pass["gTcoeffs"] = mpTcoeffs;

    // Lighting coefficient and vertex color buffers
    pass["gLcoeffs"] = mpLcoeffs;
    pass["gVertexColors"] = mpVertexColors;

    // One execution per emitter group; gLightStartIndex is the first triangle of the group.
    uint lightStart = 0;
    while (lightStart < mPrtParams.numEmitterTriangles)
    {
        pass["PerFrameCB"]["gLightStartIndex"] = lightStart;
        pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
        lightStart += mPrtParams.emitterGroupSize;
    }
}

/** Binds the scene, PRT parameters and buffers to the per-vertex-per-light compute pass
    and runs it a single time.
    \param[in] pRenderContext Render context the pass is executed on.
    \param[in] renderData Not used by this function.
*/
void PrtGlossy::computePerVertexPerLight(RenderContext* pRenderContext, const RenderData& renderData)
{
    auto& pass = mpComputePerVertexPerLightPass;

    // Scene data
    pass->getVars()->setParameterBlock("gScene", mpScene->getParameterBlock());

    // Per-frame constants
    pass["PerFrameCB"]["gParams"].setBlob(mGBufferParams);
    pass["PerFrameCB"]["gPrtParams"].setBlob(mPrtParams);

    // Legendre lookup textures and their sampler
    pass["gLegendre2345"] = mpLegendre2345;
    pass["gLegendre6789"] = mpLegendre6789;
    pass["gLegendreSampler"] = mpLegendreSampler;

    // Input buffers
    pass["gEmitterToMesh"] = mpEmitterToMesh;
    pass["gTcoeffs"] = mpTcoeffs;

    // Lighting coefficient and vertex color buffers
    pass["gLcoeffs"] = mpLcoeffs;
    pass["gVertexColors"] = mpVertexColors;

    // Single execution with dimensions (numVertex, 3, 1).
    pass->execute(pRenderContext, mPrtParams.numVertex, 3, 1);
}
