#pragma once

#include <render/integrator.h>
#include <render/scene.h>
#include "adaptive3D.h"
#include <render/photon_map.h>
#include <algorithm1.h>
#include "mala_utils.h"

/// Boundary integrator responsible for the primary-edge term.
/// Exposes both a reverse-mode entry point (`renderD`, driven by an image
/// adjoint `d_image`) and a forward-mode one (`forwardRenderD`).
struct PrimaryEdgeIntegrator : public IntegratorBoundary
{
    // ---- sampling state ---------------------------------------------------
    DiscreteDistribution edge_dist;       // distribution used to pick an entry of edge_indices (presumably; confirm in .cpp)
    std::vector<Vector2i> edge_indices;   // each entry: {face_id, edge_id}

    // ---- construction / setup ---------------------------------------------
    PrimaryEdgeIntegrator(const Scene &scene);
    PrimaryEdgeIntegrator(const Properties &props);
    void configure(const Scene &scene);

    // ---- differentiation entry points ---------------------------------------
    ArrayXd forwardRenderD(SceneAD &sceneAD, RenderOptions &options) const;
    ArrayXd renderD(SceneAD &sceneAD, RenderOptions &options,
                    const ArrayXd &d_image) const;
};

/// Boundary integrator responsible for the direct-edge term.
/// Holds two optional guiding structures (an adaptive 3D sampler and a
/// regular 3D grid) that are built by the `preprocess_*` methods.
struct DirectEdgeIntegrator : public IntegratorBoundary
{
    // ---- sampling state (declaration order preserved on purpose) -----------
    DiscreteDistribution draw_dist;
    DiscreteDistribution edge_dist;
    std::vector<Vector2i> edge_indices;             // edge_dist index -> {face_id, edge_id}
    std::vector<std::vector<int>> edge_indices_inv; // [face_id][edge_id] -> edge_dist index

    // ---- guiding structures ------------------------------------------------
    Adaptive_Sampling::adaptive3D aq_distrb;
    Grid3D_Sampling::grid3D grid_distrb;

    // ---- construction / setup ---------------------------------------------
    DirectEdgeIntegrator(const Scene &scene);
    DirectEdgeIntegrator(const Properties &props);
    void configure(const Scene &scene);
    void recompute_edge(const Scene &scene);

    // ---- guiding preprocessing ---------------------------------------------
    void preprocess_grid(const Scene &scene,
                         const Grid3D_Sampling::grid3D_config &config,
                         int max_bounces);
    void preprocess_aq(const Scene &scene,
                       const Adaptive_Sampling::adaptive3D_config &config,
                       int max_bounces);

    // ---- differentiation entry points ---------------------------------------
    ArrayXd forwardRenderD(SceneAD &sceneAD, RenderOptions &options) const;
    ArrayXd renderD(SceneAD &sceneAD, RenderOptions &options,
                    const ArrayXd &d_image) const;
};

// struct MetropolisDirectEdgeIntegrator : DirectEdgeIntegrator
// {
//     MALAOptions mala_config;
//     MetropolisDirectEdgeIntegrator(const Scene &scene): DirectEdgeIntegrator(scene) {}
//     MetropolisDirectEdgeIntegrator(const Properties &props): DirectEdgeIntegrator(props) {}
//     void configure(const Scene &scene) { DirectEdgeIntegrator::configure(scene);}
//     void load_MALA_config(const MALAOptions &config) { mala_config = config; };
//     ArrayXd renderD(SceneAD &sceneAD,
//                     RenderOptions &options, const ArrayXd &d_image) const;
//     Float mutateSmallStep(const Scene &scene, Scene &d_scene, RndSampler *sampler, int max_bounces,
//                             MALA::Mutation &mutation, MALA::KNNCache &cache,
//                             const MALA::MALAVector &current, MALA::MALAVector &proposal) const;
//     Float mutateLargeStep(const Scene &scene, RndSampler *sampler, int max_bounces, const MALA::MALAVector &current, MALA::MALAVector &proposal) const;
//     // guiding related
//     ArrayXd get_sample_vol(const Scene &scene, const Vector3i size, const Vector &min_bound, const Vector &max_bound);
//     ArrayXd get_sample_slice(const Scene &scene, int axis, Float u0, int size_1, int size_2);
//     void preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces);
//     void preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces);
//     ArrayXd diff_bsdf_test(const Float &roughness, const Float &theta0, const Float &theta1) const;
//     // void BSDF_test(const Scene &scene) const;
// };

/// Boundary integrator responsible for the indirect-edge term.
/// Like the direct variant, it can be guided by an adaptive 3D sampler or a
/// regular 3D grid built through the `preprocess_*` methods.
struct IndirectEdgeIntegrator : public IntegratorBoundary
{
    // ---- sampling state (declaration order preserved on purpose) -----------
    DiscreteDistribution draw_dist;
    DiscreteDistribution edge_dist;
    std::vector<Vector2i> edge_indices; // each entry: {face_id, edge_id}
    // NOTE: unlike DirectEdgeIntegrator, no inverse map
    // ({face_id, edge_id} -> edge_dist index) is kept here.

    // ---- guiding structures ------------------------------------------------
    Adaptive_Sampling::adaptive3D aq_distrb;
    Grid3D_Sampling::grid3D grid_distrb;

    // ---- construction / setup ---------------------------------------------
    IndirectEdgeIntegrator(const Scene &scene);
    IndirectEdgeIntegrator(const Properties &props);
    void configure(const Scene &scene);
    void recompute_edge(const Scene &scene);

    // ---- guiding preprocessing ---------------------------------------------
    void preprocess_grid(const Scene &scene,
                         const Grid3D_Sampling::grid3D_config &config,
                         int max_bounces);
    void preprocess_aq(const Scene &scene,
                       const Adaptive_Sampling::adaptive3D_config &config,
                       int max_bounces);

    // ---- differentiation entry points ---------------------------------------
    ArrayXd forwardRenderD(SceneAD &sceneAD, RenderOptions &options) const;
    ArrayXd renderD(SceneAD &sceneAD, RenderOptions &options,
                    const ArrayXd &d_image) const;
};

struct MetropolisIndirectEdgeIntegrator : IndirectEdgeIntegrator
{
    MALAOptions mala_config;
    MetropolisIndirectEdgeIntegrator(const Scene &scene): IndirectEdgeIntegrator(scene) {}
    MetropolisIndirectEdgeIntegrator(const Properties &props): IndirectEdgeIntegrator(props) {}
    void configure(const Scene &scene) { IndirectEdgeIntegrator::configure(scene);}
    void load_MALA_config(const MALAOptions &config) { mala_config = config; };
    ArrayXd renderD(SceneAD &sceneAD,
                    RenderOptions &options, const ArrayXd &d_image) const;
    // guiding related
    // void BSDF_test(const Scene &scene) const;
    void preprocess_grid(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces);
    void preprocess_aq(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces);

    // testing related
    ArrayX3d get_edge_ray(const Scene &scene, const Vector &rnd) const;
    ArrayX3d perturbe_sample(const Scene &scene, const Vector &rnd, const Vector &mutation) const;

    // plotting helper
    ArrayXd solve_Grid(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const;
    ArrayXd solve_Grid_rough(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const;
    ArrayX3d solve_MALA(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const;
    Eigen::Array<Float, -1, 4, 1> solve_MALA_rough(const Scene &scene, const Vector3i &size, 
                    const Vector &min, const Vector &max) const;
};

/// Composite boundary integrator: owns one primary-edge, one direct-edge and
/// one (Metropolis) indirect-edge integrator and forwards configuration,
/// preprocessing and differentiation calls to them.
struct BoundaryIntegrator : IntegratorBoundary
{
    PrimaryEdgeIntegrator p;
    // MetropolisDirectEdgeIntegrator d;  // alternative direct-edge integrator (currently disabled)
    DirectEdgeIntegrator d;
    MetropolisIndirectEdgeIntegrator i;
    // IndirectEdgeIntegrator i;          // alternative non-Metropolis indirect-edge integrator

    /// Builds each sub-integrator from `scene`, then runs configure(scene).
    BoundaryIntegrator(const Scene &scene) : p(scene), d(scene), i(scene) { configure(scene); }
    /// Builds each sub-integrator from `props`; no scene-dependent configure is run.
    BoundaryIntegrator(const Properties &props) : p(props), d(props), i(props) {}

    /// Passes MALA settings to the indirect-edge integrator (the only
    /// sub-integrator that currently consumes them).
    void configure_mala(const MALAOptions &config)
    {
        // d.load_MALA_config(config);
        i.load_MALA_config(config);
    }

    /// Reconfigures all three sub-integrators for `scene`.
    void configure(const Scene &scene)
    {
        p.configure(scene);
        d.configure(scene);
        i.configure(scene);
    }

    /// Reconfigures only the primary-edge integrator.
    void configure_primary(const Scene &scene)
    {
        p.configure(scene);
    }

    void recompute_direct_edge(const Scene &scene)
    {
        d.recompute_edge(scene);
    }

    void recompute_indirect_edge(const Scene &scene)
    {
        i.recompute_edge(scene);
    }

    void preprocess_grid_direct(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
    {
        d.preprocess_grid(scene, config, max_bounces);
    }

    void preprocess_aq_direct(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
    {
        d.preprocess_aq(scene, config, max_bounces);
    }

    void preprocess_grid_indirect(const Scene &scene, const Grid3D_Sampling::grid3D_config &config, int max_bounces)
    {
        i.preprocess_grid(scene, config, max_bounces);
    }

    void preprocess_aq_indirect(const Scene &scene, const Adaptive_Sampling::adaptive3D_config &config, int max_bounces)
    {
        i.preprocess_aq(scene, config, max_bounces);
    }

    /// Reverse-mode pass: sum of the primary, direct and indirect edge terms.
    ///
    /// The three sub-integrators share the mutable `sceneAD` and `options`.
    /// Their calls are sequenced explicitly (p, then d, then i): in a single
    /// `a() + b() + c()` expression C++ leaves the operand evaluation order
    /// unspecified, so any side effects on the shared state could otherwise
    /// happen in a compiler-dependent order. The summation itself is still
    /// ((p + d) + i).
    ArrayXd renderD(SceneAD &sceneAD,
                    RenderOptions &options, const ArrayXd &d_image) const
    {
        ArrayXd result = p.renderD(sceneAD, options, d_image);
        result += d.renderD(sceneAD, options, d_image);
        result += i.renderD(sceneAD, options, d_image);
        return result;
    }

    /// Forward-mode pass: sum of the primary, direct and indirect edge terms,
    /// evaluated in a fixed order for the same reason as renderD.
    ArrayXd forwardRenderD(SceneAD &sceneAD, RenderOptions &options) const
    {
        ArrayXd result = p.forwardRenderD(sceneAD, options);
        result += d.forwardRenderD(sceneAD, options);
        result += i.forwardRenderD(sceneAD, options);
        return result;
    }
};
