// Copyright @yucwang 2022

#pragma once
#ifndef LIBTENSORRAY_ALGORITHM1_DIRECT2_H_
#define LIBTENSORRAY_ALGORITHM1_DIRECT2_H_
#endif
#include "../Ray.h"
#include "../Records.h"
#include "Algorithm1.h"
#include "Path2.h"
#include "PathSampler.h"
#include "ReSTIR/ScreenSpaceReSTIR.h"
#include "ReSTIR/MaterialSpaceReSTIR.h"

namespace EDX
{
    namespace TensorRay
    {
        // Direct2: a PathTracer whose path sampler is bound (in the constructor)
        // to a MaterialSpaceReSTIRPathSampler, and which exposes an Algorithm1-style
        // EvalPath. Most member functions are declared here and defined out of line.
        class Direct2 : public PathTracer
        {
        public:
            // Stores the sampler returned by MaterialSpaceReSTIRPathSampler::GetInstance()
            // (presumably a shared singleton — confirm against its definition).
            Direct2()
                : PathTracer()
                , mPathSampler(MaterialSpaceReSTIRPathSampler::GetInstance())
            {
            }

            // Applies render settings from `options` (defined out of line).
            void SetParam(const RenderOptions& options);

            // Integrates `scene` into `image` (defined out of line).
            void Integrate(const Scene& scene, Tensorf& image) const;

            // Path-sampling hook overridden from PathTracer (defined out of line).
            void SamplePath(const Scene& scene) const override;

            // Derivative render pass driven by the image-space loss gradient `dLdI`.
            //   scene   - scene to render; its first sensor defines the film size.
            //   options - render settings; mExportDerivative toggles gradient-image export.
            //   dLdI    - incoming loss gradient, held in mDLoss for the duration of the call.
            // Returns a zero-initialized flat Vec3 tensor of GetFilmSizeX()*GetFilmSizeY()
            // entries after Integrate() has run on it (and, when exporting derivatives,
            // after mGradHandler.GetGradientImages() has been given it).
            virtual Tensorf RenderD(const Scene& scene, const RenderOptions& options, const Tensorf& dLdI)
            {
                const Camera& sensor = *scene.mSensors[0];

                // Keep a copy of the loss gradient for this call (presumably read while
                // integrating); it is released by mDLoss.Free() before returning.
                mDLoss = dLdI;
                SetParam(options);

                const auto filmWidth = sensor.GetFilmSizeX();
                const auto filmHeight = sensor.GetFilmSizeY();
                const bool exportDerivative = options.mExportDerivative;

                if (exportDerivative)
                {
                    mGradHandler.InitGradient(filmWidth, filmHeight);
                }

                // Flat buffer with one Vec3 entry per (x, y) film location.
                Tensorf gradImage = Zeros(Shape({ filmWidth * filmHeight }, VecType::Vec3));
                Integrate(scene, gradImage);

                if (exportDerivative)
                {
                    mGradHandler.GetGradientImages(gradImage);
                }

                // Per-call gradient-handler state is reset whether or not derivatives were exported.
                mGradHandler.ClearGradImages();
                mGradHandler.ClearIndex();
                mDLoss.Free();

                return gradImage;
            }

            // Per-iteration update hook overridden from PathTracer (defined out of line).
            virtual void Step() override;

            // Algorithm1
            // Evaluates the given spatial paths with contribution terms `h` into an
            // expression for an image of shape `imageShape` (defined out of line).
            Expr EvalPath(const Scene& scene, const PathSpatial& paths, const PathContribTerms& h, const Shape& imageShape) const;

            // Sampler bound in the constructor.
            std::shared_ptr<MaterialSpaceReSTIRPathSampler> mPathSampler;

            // Sampled-path result; mutable so const members (e.g. SamplePath) can update it.
            mutable std::shared_ptr<PathSampleResult> mPath;
        };
    }
}