/***************************************************************************
 # Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
 # and any modifications thereto.  Any use, reproduction, disclosure or
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 **************************************************************************/
#include "Falcor.h"
#include "RenderGraph/RenderPassStandardFlags.h"
#include "PrtPrecompute.h"

namespace
{
    // Shader source files; init() resolves them relative to PROJECT_DIR.
    const char* const kPrecomputeShaderFile = "precompute.rt.slang";
    const char* const kPixelShaderFile = "pixel.cs.slang";

    // Ray tracing entry points used to build the RtProgram in init().
    const char* const kEntryPointRayGen = "rayGen";
    const char* const kEntryPointMiss0 = "primaryMiss";
    const char* const kEntryPrimaryAnyHit = "primaryAnyHit";
    const char* const kEntryPrimaryClosestHit = "primaryClosestHit";

    // Scripting dictionary keys (read in parseDictionary(), written in getScriptingDictionary()).
    const char* const kOutDir = "outDir";
    const char* const kOutFilename = "outFilename";

    // Pass outputs, declared as UAVs in reflect() and bound in execute().
    // Fields: { name, shader texture variable, description, optional, format }.
    const ChannelList kOutputChannels =
    {
        { "debug", "gDebug", "Debug information", false, ResourceFormat::RGBA32Float },
    };
}

// Required for hot-reload to function properly; do not remove.
// Exposes the compile-time PROJECT_DIR path of this render-pass library.
extern "C" __declspec(dllexport) const char* getProjDir()
{
    const char* projectDir = PROJECT_DIR;
    return projectDir;
}

// Registers PrtPrecompute with the render pass library so it can be created by name.
// The description string is what users see when browsing available passes.
extern "C" __declspec(dllexport) void getPasses(Falcor::RenderPassLibrary& lib)
{
    lib.registerClass("PrtPrecompute", "Precomputes per-vertex PRT transfer coefficients", PrtPrecompute::create);
}

// Factory registered in getPasses().
// Constructs the pass and runs init(dict). Returns nullptr if init() returns false;
// note that init() signals most failures by throwing instead.
// pRenderContext is unused.
PrtPrecompute::SharedPtr PrtPrecompute::create(RenderContext* pRenderContext, const Dictionary& dict)
{
    SharedPtr pPass(new PrtPrecompute);
    if (!pPass->init(dict)) return nullptr;
    return pPass;
}

bool PrtPrecompute::init(const Dictionary& dict)
{
    // Create program
    std::string shaderDir = std::string(PROJECT_DIR);
    mOutDir = shaderDir;

    // Parse dictionary parameters
    bool flag = parseDictionary(dict);
    if (!flag) throw std::exception("Failed to parse dictionary");

    // Create ray tracing program
    RtProgram::Desc progDesc;
    progDesc.addShaderLibrary(shaderDir + kPrecomputeShaderFile).setRayGen(kEntryPointRayGen);
    progDesc.addHitGroup(0, kEntryPrimaryClosestHit, kEntryPrimaryAnyHit).addMiss(0, kEntryPointMiss0);
    mRaytrace.pProgram = RtProgram::create(progDesc);
    if (!mRaytrace.pProgram) throw std::exception("Failed to create program");

    // Initialize ray tracing state
    mRaytrace.pState = RtState::create();
    mRaytrace.pState->setMaxTraceRecursionDepth(1);     // Max trace depth 1 allows TraceRay to be called from RGS, but no secondary rays.
    mRaytrace.pState->setProgram(mRaytrace.pProgram);

    // Create random engine
    mpSampleGenerator = SampleGenerator::create(SAMPLE_GENERATOR_UNIFORM);
    if (!mpSampleGenerator) throw std::exception("Failed to create sample generator");
    mpSampleGenerator->prepareProgram(mRaytrace.pProgram.get());

    // Create pixel shader program
    Program::Desc desc;
    desc.addShaderLibrary(shaderDir + kPixelShaderFile).csEntry("main").setShaderModel("5_1");
    mpPixelPass = ComputePass::create(desc, Program::DefineList(), true);
    if (!mpPixelPass) throw std::exception("Failed to create program");

    return true;
}

bool PrtPrecompute::parseDictionary(const Dictionary& dict)
{
    for (const auto& v : dict)
    {
        if (v.key() == kOutDir)
        {
            std::string tmpStr = v.val();
            mOutDir = tmpStr;
        }
        else if (v.key() == kOutFilename)
        {
            std::string tmpStr = v.val();
            mOutFilename = tmpStr;
        }
        else
        {
            logWarning("Unknown field `" + v.key() + "` in a GBuffer dictionary");
        }
    }
    return true;
}

// Serializes the pass options (output directory and output file name) for scripting.
// These use the same keys that parseDictionary() reads.
Dictionary PrtPrecompute::getScriptingDictionary()
{
    Dictionary options;
    options[kOutDir] = mOutDir;
    options[kOutFilename] = mOutFilename;
    return options;
}

// Declares every entry of kOutputChannels as an output bound for unordered access (UAV).
// Entries flagged optional are marked Optional in the reflection.
// compileData is unused.
RenderPassReflection PrtPrecompute::reflect(const CompileData& compileData)
{
    RenderPassReflection reflector;

    // Iterate by const reference: ChannelDesc holds strings, so by-value iteration copies each one.
    for (const auto& channel : kOutputChannels)
    {
        auto& field = reflector.addOutput(channel.name, channel.desc)
                          .format(channel.format)
                          .bindFlags(Resource::BindFlags::UnorderedAccess);
        if (channel.optional) field.flags(RenderPassReflection::Field::Flags::Optional);
    }

    return reflector;
}

// Caches the graph's default texture dimensions.
// execute() uses them as the dispatch size of the pixel compute pass.
void PrtPrecompute::compile(RenderContext* pContext, const CompileData& compileData)
{
    const auto& defaultDims = compileData.defaultTexDims;
    mFrameSize = defaultDims;
}

void PrtPrecompute::preprocessScene()
{
    mNumVertex = 0;

    uint32_t numMeshInstance = mpScene->getMeshInstanceCount();
    for (uint32_t i = 0; i < numMeshInstance; i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const MeshDesc& mesh = mpScene->getMesh(inst.meshID);
        const auto& mat = mpScene->getMaterial(inst.materialID);

        if (mat->isEmissive())
            break;

        mNumVertex += mesh.vertexCount;
    }

    // Create VertexToMesh buffer
    mVertexToMeshData.clear();
    mVertexToMeshData.reserve(mNumVertex);
    for (uint32_t i = 0; i < numMeshInstance; i++)
    {
        const MeshInstanceData& inst = mpScene->getMeshInstance(i);
        const MeshDesc& mesh = mpScene->getMesh(inst.meshID);

        for (uint j = 0; j < mesh.vertexCount; j++)
        {
            mVertexToMeshData.push_back(uint2(i, j));
        }
    }

    if (mpVertexToMesh) mpVertexToMesh = nullptr;
    mpVertexToMesh = TypedBuffer<uint2>::create((uint32_t)mVertexToMeshData.size());
    mpVertexToMesh->setBlob((void*)(mVertexToMeshData.data()), 0, mVertexToMeshData.size() * sizeof(uint2));

    // Create buffer
    mpTcoeffs = TypedBuffer<float>::create(mNumVertex * 81);

    mSqrtSpp = 128;
    //mSqrtSpp = 256;
    mSampleCount = 0;

    FILE *fp = fopen((mOutDir + "debug_scene_T.txt").c_str(), "w");
    fprintf(fp, "Num vertex = %d\n", mNumVertex);
    fclose(fp);
}

// Attaches a scene to the pass.
//
// A null scene detaches the current one and drops the scene-dependent program vars. While
// mpScene is null, execute() only warns and returns.
//
// For a non-null scene: the scene's defines are added to the ray tracing program, program vars
// are created for it, and the per-vertex precomputation state is rebuilt via preprocessScene().
// Throws std::exception if the program vars cannot be created.
void PrtPrecompute::setScene(RenderContext* pRenderContext, const Scene::SharedPtr& pScene)
{
    mpScene = pScene;
    if (!mpScene)
    {
        mRaytrace.pVars = nullptr;
        return;
    }

    mRaytrace.pProgram->addDefines(pScene->getSceneDefines());
    mRaytrace.pVars = RtProgramVars::create(mRaytrace.pProgram, pScene);
    if (!mRaytrace.pVars) throw std::exception("Failed to create program vars");

    preprocessScene();
}

// Runs one progressive iteration of the transfer precomputation.
//
// Binds the sample generator, the sampling parameters (gSqrtSpp, gSampleCount), the vertex
// lookup table (gVertexToMesh) and the coefficient buffer (gTcoeffs). It then dispatches
// mNumVertex x 1 x 1 ray-gen threads.
//
// renderData is unused. Throws std::exception if the sample generator cannot be bound.
void PrtPrecompute::precomputeTransfer(RenderContext* pRenderContext, const RenderData& renderData)
{
    GraphicsVars::SharedPtr pVars = mRaytrace.pVars->getGlobalVars();
    if (!mpSampleGenerator->setIntoProgramVars(pVars.get()))
    {
        throw std::exception("Failed to bind sample generator");
    }

    // Sampling parameters.
    pVars["PerFrameCB"]["gSqrtSpp"] = mSqrtSpp;
    pVars["PerFrameCB"]["gSampleCount"] = mSampleCount;

    // Input lookup table, then output accumulation buffer.
    pVars["gVertexToMesh"] = mpVertexToMesh;
    pVars["gTcoeffs"] = mpTcoeffs;

    uvec3 dispatchDim(mNumVertex, 1u, 1u);
    mpScene->raytrace(pRenderContext, mRaytrace.pState, mRaytrace.pVars, dispatchDim);
}

// Per-frame entry point.
//
// Until precomputation finishes, each call dispatches one progressive iteration. Once
// mSqrtSpp * mSqrtSpp iterations have run, the accumulated coefficients are normalized by
// that count and written to mOutDir + mOutFilename: one line per vertex, 81 values per line.
//
// Every frame, the pixel compute pass is run over mFrameSize to fill the output channels.
// If no scene is set, logs a warning and returns without doing anything.
void PrtPrecompute::execute(RenderContext* pRenderContext, const RenderData& renderData)
{
    if (mpScene == nullptr)
    {
        logWarning("PrtPrecompute::execute() - No scene available");
        return;
    }

    if (!mPrecomputeDone)
    {
        precomputeTransfer(pRenderContext, renderData);

        mSampleCount++;
        if (mSampleCount == mSqrtSpp * mSqrtSpp)
        {
            mPrecomputeDone = true;

            // Per-vertex stride of mpTcoeffs; must match the allocation in preprocessScene().
            const uint kCoeffsPerVertex = 81;
            const float normFactor = 1.f / (mSqrtSpp * mSqrtSpp);
            const std::string outPath = mOutDir + mOutFilename;

            FILE* fp = fopen(outPath.c_str(), "w");
            if (!fp)
            {
                logWarning("PrtPrecompute::execute() - Failed to open output file `" + outPath + "`");
            }
            else
            {
                const float* bufData = static_cast<const float*>(mpTcoeffs->getData());
                for (uint i = 0; i < mNumVertex; i++)
                {
                    const float* coeffs = bufData + size_t(i) * kCoeffsPerVertex;
                    for (uint j = 0; j < kCoeffsPerVertex; j++)
                        fprintf(fp, "%.6f ", coeffs[j] * normFactor);
                    fprintf(fp, "\n");
                }
                fclose(fp);
            }
        }
    }

    mpPixelPass["PerFrameCB"]["gSqrtSpp"] = mSqrtSpp;
    mpPixelPass["PerFrameCB"]["gSampleCount"] = mSampleCount;

    // Bind outputs: clear each available output texture, then bind it (possibly null) to its shader variable.
    auto bind = [&](const ChannelDesc& output)
    {
        // Check the resource before dereferencing it; an unconnected output would otherwise crash.
        Texture::SharedPtr pTex;
        if (const auto& pRes = renderData[output.name]) pTex = pRes->asTexture();
        if (pTex) pRenderContext->clearUAV(pTex->getUAV().get(), glm::vec4(0, 0, 0, 0));
        mpPixelPass[output.texname] = pTex;
    };
    for (const auto& channel : kOutputChannels) bind(channel);

    mpPixelPass->execute(pRenderContext, mFrameSize.x, mFrameSize.y, 1);
}

// This pass exposes no UI controls.
void PrtPrecompute::renderUI(Gui::Widgets& widget)
{
    (void)widget;
}
