#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

#include <render/scene.h>
#include <core/properties.h>
#include <core/cube_distrb.h>
#include <core/spectrum.h>
#include <emitter/area.h>
#include <emitter/envmap.h>
// #include <emitter/point.h>
#include <bsdf/null.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughconductor.h>
#include <bsdf/roughdielectric.h>

#include <render/medium.h>
#include <medium/homogeneous.h>
#include <medium/heterogeneous.h>
#include <tetra.hpp>
#include <render/phase.h>
#include <render/volumegrid.h>

// #include <integrator/direct.h>
// #include <integrator/direct2.h>
// #include <integrator/path.h>
#include <integrator/path2.h>
#include "integrator/mask.h"
// #include <integrator/volpath.h>
// #include <integrator/ptracer2.h>
// #include <integrator/bdpt2.h>
#include <integrator/boundary.h>
// #include <integrator/vol_boundary.h>
#include <integrator/boundary_pixel.h>
#include "bsdf/roughdielectric.h"

#include <config.h>

// #include "test/test.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/eigen.h>
#include <core/transform.h>
#include <pybind11/operators.h>

#include "light_path.h"

#include "adaptive3D.h"

#include <core/bitmap.h>
#include <render/imageblock.h>
#include <render/common.h>
#include <algorithm1.h>
#include <igl/copyleft/tetgen/tetrahedralize.h>
#include <scene_loader.h>
#include <core/properties.h>
#include <core/logger.h>
// #include <integrator/new_volpath/volpath_merged.h>
// #include <integrator/new_volpath/volpath2.h>
// #include <integrator/new_volpath/boundary_bidir.h>
// #include <integrator/new_volpath/boundary_unidir.h>
// #include <integrator/new_volpath/interior.h>
// #include <integrator/test.h>

namespace py = pybind11;
using namespace py::literals;

/**
 * Tetrahedralize a triangle mesh with TetGen (through libigl) and return the
 * result by value so it can be exposed to Python as a tuple.
 *
 * @param V         Input vertex positions, forwarded unchanged to
 *                  igl::copyleft::tetgen::tetrahedralize.
 * @param F         Input triangle indices into V, forwarded unchanged.
 * @param switches  TetGen switch string, forwarded unchanged.
 * @return          (TV, TT, TF) exactly as filled in by libigl: output vertex
 *                  positions, tetrahedron indices and face indices.
 * @throws std::runtime_error if libigl reports a non-zero status; pybind11
 *         translates this into a Python RuntimeError instead of silently
 *         returning empty or partial matrices.
 */
std::tuple<Eigen::MatrixXd, Eigen::MatrixXi, Eigen::MatrixXi> tetrahedralize1(const Eigen::MatrixXd &V, const Eigen::MatrixXi &F, const std::string &switches)
{
    Eigen::MatrixXd TV;
    Eigen::MatrixXi TT;
    Eigen::MatrixXi TF;

    // libigl signals failure only through its return value (0 == success).
    // NOTE(review): per the libigl docs, 1 = TetGen threw, 2 = no tets were
    // produced (e.g. holes / duplicate faces) and -1 = other error; confirm
    // against the vendored libigl version.
    const int status = igl::copyleft::tetgen::tetrahedralize(V, F, switches, TV, TT, TF);
    if (status != 0)
    {
        throw std::runtime_error("tetrahedralize: TetGen failed with status " +
                                 std::to_string(status) + " (switches=\"" + switches + "\")");
    }

    // Move the locals into the tuple; a braced return of lvalues would
    // deep-copy all three (potentially large) matrices.
    return {std::move(TV), std::move(TT), std::move(TF)};
}

PYBIND11_MODULE(psdr_cpu, m)
{
    // static initialization
    Logger::static_init();
    PSDR_INFO("psdr_cpu loaded");

    m.doc() = "psdr_cpu"; // optional module docstring

    py::class_<ptr<float>>(m, "float_ptr")
        .def(py::init<std::size_t>());
    py::class_<ptr<int>>(m, "int_ptr")
        .def(py::init<std::size_t>());

    py::class_<Spectrum3f>(m, "Spectrum3f")
        .def(py::init<float, float, float>());

    py::class_<Vector3f>(m, "Vector3f")
        .def(py::init<float, float, float>());

    py::class_<Ray>(m, "Ray")
        .def(py::init<Vector, Vector>())
        .def_readwrite("org", &Ray::org)
        .def_readwrite("dir", &Ray::dir);
    py::class_<Object>(m, "Object");
    py::class_<ImageBlock>(m, "ImageBlock")
        .def(py::init<Array2i, Array2i>())
        .def("put", static_cast<void (ImageBlock::*)(const Vector2i &, const Spectrum &)>(&ImageBlock::put))
        .def("put", static_cast<void (ImageBlock::*)(const ImageBlock &)>(&ImageBlock::put))
        .def_readwrite("offset", &ImageBlock::m_offset)
        .def_readwrite("blockSize", &ImageBlock::m_blockSize)
        .def_readwrite("data", &ImageBlock::m_data);

    py::class_<VolumeGrid, std::unique_ptr<VolumeGrid, py::nodelete>>(m, "VolumeGrid")
        .def(py::init<const Properties &>())
        .def(py::init<const Vector3i &, int,
                      const Vector &, const Vector &,
                      const std::vector<Float> &,
                      const Matrix4x4 &>(),
             "res"_a, "channels"_a, "aabb_min"_a, "aabb_max"_a, "data"_a, "transform"_a = Matrix4x4::Identity())
        .def(py::init<const Spectrum &>())
        .def(py::init<const std::string &>())
        .def("setZero", &VolumeGrid::setZero)
        .def("lookupFloat", &VolumeGrid::lookupFloat)
        .def("lookupSpectrum", &VolumeGrid::lookupSpectrum)
        .def_readwrite("size", &VolumeGrid::m_res)
        .def_readwrite("m_nchannel", &VolumeGrid::m_nchannel)
        .def_readwrite("m_data", &VolumeGrid::m_data)
        .def_readwrite("m_max", &VolumeGrid::m_max)
        .def_readwrite("m_channel_max", &VolumeGrid::m_channel_max)
        .def_readwrite("is_constant", &VolumeGrid::is_constant)
        .def_readwrite("m_volumeToGrid", &VolumeGrid::m_volumeToGrid)
        .def_readwrite("m_volumeToWorld", &VolumeGrid::m_volumeToWorld)
        .def_readwrite("m_worldToVolume", &VolumeGrid::m_worldToVolume)
        .def_readwrite("m_worldToGrid", &VolumeGrid::m_worldToGrid);

    py::enum_<EVertex>(m, "EVertex")
        .value("EVInvalid", EVInvalid)
        .value("EVSensor", EVSensor)
        .value("EVSurface", EVSurface)
        .value("EVEmitter", EVEmitter)
        .value("EVVolume", EVVolume)
        .value("EVNull", EVNull)
        .export_values();

    py::class_<Properties>(m, "Properties")
        .def(py::init<>())
        .def(py::init<const Properties::PropertiesPrivate&>())
        .def("merge", &Properties::merge)
        .def("get", static_cast<int (Properties::*)(const std::string &) const>(&Properties::get<int>))
        .def("set", &Properties::set<bool>)
        .def("set", &Properties::set<int>)
        .def("set", &Properties::set<double>)
        .def("set", &Properties::set<Matrix4x4>)
        .def("set", &Properties::set<std::string>)
        .def("setVectorX", &Properties::set<std::vector<Vector>>)
        .def("setVectorX2", &Properties::set<std::vector<Vector2>>)
        .def("setVectorX3i", &Properties::set<std::vector<Vector3i>>)
        .def("setBitmap", &Properties::set<Bitmap>)
        .def("set", &Properties::set<Properties>)
        .def("data", &Properties::data);
    py::implicitly_convertible<Properties::PropertiesPrivate, Properties>();

    py::class_<Intersection>(m, "Intersection")
        .def(py::init<>())
        .def_readwrite("ptr_shape", &Intersection::ptr_shape)
        .def_readwrite("ptr_med_int", &Intersection::ptr_med_int)
        .def_readwrite("ptr_med_ext", &Intersection::ptr_med_ext)
        .def_readwrite("ptr_bsdf", &Intersection::ptr_bsdf)
        .def_readwrite("ptr_emitter", &Intersection::ptr_emitter)
        .def_readwrite("t", &Intersection::t)
        .def_readwrite("p", &Intersection::p)
        .def_readwrite("geoFrame", &Intersection::geoFrame)
        .def_readwrite("shFrame", &Intersection::shFrame)
        .def_readwrite("uv", &Intersection::uv)
        .def_readwrite("wi", &Intersection::wi)
        .def_readwrite("J", &Intersection::J)
        .def_readwrite("type", &Intersection::type)
        .def_readwrite("pdf", &Intersection::pdf)
        .def_readwrite("shape_id", &Intersection::shape_id)
        .def_readwrite("triangle_id", &Intersection::triangle_id)
        .def_readwrite("int_med_id", &Intersection::int_med_id)
        .def_readwrite("ext_med_id", &Intersection::ext_med_id)
        .def_readwrite("medium_id", &Intersection::medium_id)
        .def_readwrite("pixel_idx", &Intersection::pixel_idx)
        .def_readwrite("value", &Intersection::value)
        .def_readwrite("nee_bsdf", &Intersection::nee_bsdf)
        .def_readwrite("bsdf_bsdf", &Intersection::bsdf_bsdf)
        .def_readwrite("l_nee_id", &Intersection::l_nee_id)
        .def_readwrite("l_bsdf_id", &Intersection::l_bsdf_id)
        .def_readwrite("barycentric", &Intersection::barycentric)
        .def_readwrite("indices", &Intersection::indices);

    py::class_<LightPath>(m, "LightPath")
        .def(py::init<>())
        .def_readwrite("pixel_idx", &LightPath::pixel_idx)
        .def_readwrite("vertices", &LightPath::vertices)
        .def_readwrite("vs", &LightPath::vs);

    py::class_<LightPathAD>(m, "LightPathAD")
        .def(py::init<const LightPath &>())
        .def_readwrite("val", &LightPathAD::val)
        .def_readwrite("der", &LightPathAD::der);

    py::class_<RadianceQueryRecord>(m, "RadianceQueryRecord")
        .def(py::init<const Array2i &, RndSampler *, int>())
        .def_readwrite("pixel_idx", &RadianceQueryRecord::pixel_idx)
        .def_readwrite("sampler", &RadianceQueryRecord::sampler)
        .def_readwrite("max_bounces", &RadianceQueryRecord::max_bounces);

    py::class_<PixelQueryRecord, RadianceQueryRecord>(m, "PixelQueryRecord")
        .def(py::init<Array2i, RndSampler *, int, int, bool>(),
             "pixel_idx"_a, "sampler"_a, "max_bounces"_a, "nsamples"_a, "enable_antithetic"_a = true);

    py::class_<RGBSpectrum<Float, 1, 1>>(m, "RGBSpectrum", py::buffer_protocol())
        .def(py::init<float, float, float>())
        .def("data", &RGBSpectrum<Float, 1, 1>::data)
        .def("rows", &RGBSpectrum<Float, 1, 1>::rows)
        .def("cols", &RGBSpectrum<Float, 1, 1>::cols)
        .def_buffer([](RGBSpectrum<Float, 1, 1> &s) -> py::buffer_info
                    { return py::buffer_info(
                          s.data(),
                          sizeof(Float),
                          py::format_descriptor<Float>::format(),
                          1,
                          {3},
                          {sizeof(Float)}); });

    py::class_<RndSampler>(m, "RndSampler")
        .def(py::init<int, int>())
        .def("next1D", &RndSampler::next1D)
        .def("next2D", &RndSampler::next2D)
        .def("next3D", &RndSampler::next3D)
        .def("save", &RndSampler::save)
        .def("restore", &RndSampler::restore)
        .def_readonly("saved", &RndSampler::saved)
        .def_readonly("state", &RndSampler::state)
        .def_readonly("inc", &RndSampler::inc);

    // aq

    py::class_<DiscreteDistribution>(m, "DiscreteDistribution")
        .def(py::init<>())
        .def_readwrite("m_cdf", &DiscreteDistribution::m_cdf);

    py::class_<CubeDistribution>(m, "CubeDistribution")
        .def(py::init<>())
        .def_readwrite("m_distrb", &CubeDistribution::m_distrb)
        .def_readwrite("m_res", &CubeDistribution::m_res)
        .def_readwrite("m_num_cells", &CubeDistribution::m_num_cells)
        .def_readwrite("m_cells", &CubeDistribution::m_cells)
        .def_readwrite("m_unit", &CubeDistribution::m_unit)
        .def_readwrite("m_ready", &CubeDistribution::m_ready);
    

    py::class_<Sort_config>(m, "Sort_config")
        .def(py::init<>())
        .def(py::init<int, Float, Float>())
        .def_readwrite("max_length", &Sort_config::max_length)
        .def_readwrite("cos_val_local", &Sort_config::cos_val_local)
        .def_readwrite("cos_val_global", &Sort_config::cos_val_global);

    py::class_<Adaptive_Sampling::adaptive3D>(m, "adaptive3D")
        .def("printTree", &Adaptive_Sampling::adaptive3D::print)
        .def("sample", &Adaptive_Sampling::adaptive3D::sample);

    py::class_<Adaptive_Sampling::adaptive3D_config>(m, "aq_config")
        .def(py::init<>())
        .def(py::init<double, int, double, int, int>())
        .def_readwrite("thold", &Adaptive_Sampling::adaptive3D_config::thold)
        .def_readwrite("spg", &Adaptive_Sampling::adaptive3D_config::spg)
        .def_readwrite("min_spg", &Adaptive_Sampling::adaptive3D_config::min_spg)
        .def_readwrite("sample_decay", &Adaptive_Sampling::adaptive3D_config::sample_decay)
        .def_readwrite("weight_decay", &Adaptive_Sampling::adaptive3D_config::weight_decay)
        .def_readwrite("max_depth", &Adaptive_Sampling::adaptive3D_config::max_depth)
        .def_readwrite("npass", &Adaptive_Sampling::adaptive3D_config::npass)
        .def_readwrite("use_heap", &Adaptive_Sampling::adaptive3D_config::use_heap)
        .def_readwrite("edge_draw", &Adaptive_Sampling::adaptive3D_config::edge_draw)
        .def_readwrite("max_depth_x", &Adaptive_Sampling::adaptive3D_config::max_depth_x)
        .def_readwrite("max_depth_y", &Adaptive_Sampling::adaptive3D_config::max_depth_y)
        .def_readwrite("max_depth_z", &Adaptive_Sampling::adaptive3D_config::max_depth_z)
        .def_readwrite("eps", &Adaptive_Sampling::adaptive3D_config::eps)
        .def_readwrite("shape_opt_id", &Adaptive_Sampling::adaptive3D_config::shape_opt_id)
        .def_readwrite("local_backward", &Adaptive_Sampling::adaptive3D_config::local_backward);

    py::class_<Grid3D_Sampling::grid3D_config>(m, "grid_config")
        .def(py::init<const Vector3i &, int>())
        .def_readwrite("dims", &Grid3D_Sampling::grid3D_config::dims)
        .def_readwrite("spp", &Grid3D_Sampling::grid3D_config::spp);

    py::class_<ReconstructionFilter>(m, "ReconstructionFilter");

    py::class_<TentFilter, ReconstructionFilter>(m, "TentFilter")
        .def(py::init<Float>())
        .def_readwrite("padding", &TentFilter::m_padding);

    py::class_<BoxFilter, ReconstructionFilter>(m, "BoxFilter")
        .def(py::init<>());

    py::class_<AnisotropicGaussianFilter, ReconstructionFilter>(m, "AnisotropicGaussianFilter")
        .def(py::init<>());

    py::class_<CropRectangle>(m, "CropRectangle")
        .def_readwrite("offset_x", &CropRectangle::offset_x)
        .def_readwrite("offset_y", &CropRectangle::offset_y)
        .def_readwrite("crop_width", &CropRectangle::crop_width)
        .def_readwrite("crop_height", &CropRectangle::crop_height);

    py::class_<Camera, std::unique_ptr<Camera, py::nodelete>>(m, "Camera")
        .def(py::init<int, int, const Matrix4x4 &, const Matrix3x3 &, int, int>())
        .def(py::init<int, int, Float, const Matrix4x4 &, int>(),
             "width"_a, "height"_a, "fov"_a, "to_world"_a, "type"_a)
        .def(py::init<int, int, Float, const Vector &, const Vector &, const Vector &, int>(),
             "width"_a, "height"_a, "fov"_a, "origin"_a, "target"_a, "up"_a, "type"_a)
        .def(py::init<const Properties>(), "props"_a)
        .def_readwrite("cpos", &Camera::cpos)
        .def_readonly("world_to_cam", &Camera::world_to_cam)
        .def_readwrite("cam_to_world", &Camera::cam_to_world)
        .def_readonly("cam_to_ndc", &Camera::cam_to_ndc)
        .def_readwrite("width", &Camera::width)
        .def_readwrite("height", &Camera::height)
        .def_readwrite("fov", &Camera::m_fov)
        .def_readwrite("rfilter", &Camera::rfilter)
        .def_readwrite("sigX", &Camera::sigX)
        .def_readwrite("sigY", &Camera::sigY)
        .def_readwrite("rect", &Camera::rect)
        .def("set_rect", &Camera::setCropRect)
        .def("getCropSize", &Camera::getCropSize)
        .def("samplePrimaryRay", static_cast<Ray (Camera::*)(Float, Float) const>(&Camera::samplePrimaryRay))
        .def("samplePrimaryBoundaryRay", &Camera::samplePrimaryBoundaryRay);

    py::class_<Edge, std::unique_ptr<Edge, py::nodelete>>(m, "Edge")
        .def_readonly("v0", &Edge::v0)
        .def_readonly("v1", &Edge::v1)
        .def_readonly("f0", &Edge::f0)
        .def_readonly("f1", &Edge::f1);

    py::class_<Shape, std::unique_ptr<Shape, py::nodelete>>(m, "Shape")
        .def(py::init<const Properties &>())
        .def(py::init<const std::vector<Vector> &, const std::vector<Vector3i> &,
                      const std::vector<Vector2> &, const std::vector<Vector> &,
                      int, int, int, int, int, int>())
        .def(py::init<const std::vector<Vector> &, const std::vector<Vector3i> &,
                      const std::vector<Vector2> &, const std::vector<Vector> &,
                      int, int, int, int, int, int, const Sort_config &>())
        .def("load", &Shape::load)
        .def("setZero", &Shape::setZero)
        .def("configure", &Shape::configure)
        .def("setVertex", &Shape::setVertex)
        .def("setVertices", &Shape::setVertices)
        .def_readonly("num_vertices", &Shape::num_vertices)
        .def_readonly("light_id", &Shape::light_id)
        .def("has_uvs", &Shape::hasUVs)
        .def("has_normals", &Shape::hasNormals)
        .def_readwrite("to_world", &Shape::m_to_world)
        .def_readwrite("vertices_world", &Shape::vertices)
        .def_property("vertices", &Shape::getVertices, &Shape::setVertices)
        .def_readwrite("vertices_raw", &Shape::vertices_raw)
        .def_readwrite("normals", &Shape::normals)
        .def_readwrite("faceNormals", &Shape::faceNormals)
        .def_readwrite("uvs", &Shape::uvs)
        .def_readwrite("uv_indices", &Shape::uv_indices)
        .def_readwrite("use_face_normals", &Shape::m_use_face_normals)
        .def_readwrite("edges", &Shape::edges)
        .def_readwrite("adjacentFaces", &Shape::adjacentFaces)
        .def_readwrite("indices", &Shape::indices)
        .def_readwrite("bsdf_id", &Shape::bsdf_id)
        .def_readwrite("light_id", &Shape::light_id)
        .def_readwrite("med_int_id", &Shape::med_int_id)
        .def_readwrite("med_ext_id", &Shape::med_ext_id)
        .def_readwrite("enable_edge", &Shape::enable_edge)
        .def_readwrite("enable_draw", &Shape::enable_draw)
        .def("getEdges", &Shape::getEdges)
        .def_readwrite("sort_config", &Shape::sort_config)
        .def("save", &Shape::save)
        .def("setTranslation", &Shape::setTranslation)
        .def("setRotation", &Shape::setRotation)
        .def("setScale", &Shape::setScale)
        .def_readwrite("velocities", &Shape::velocities)
        .def("setVelocities", &Shape::setVelocities)
        .def_readwrite("vertex_idx", &Shape::vertex_idx)
        .def_readwrite("requires_grad", &Shape::requires_grad)
        .def_readwrite("param", &Shape::param);

#ifdef NORMAL_PREPROCESS
    m.def("d_shape_precompute_normal", static_cast<void (*)(const Shape &, Shape &)>(&d_precompute_normal));
    m.def("d_scene_precompute_normal", static_cast<void (*)(const Scene &, Scene &)>(&d_precompute_normal));
#endif

    py::class_<Bitmap, std::unique_ptr<Bitmap, py::nodelete>>(m, "Bitmap")
        .def(py::init<>())
        .def(py::init<const Spectrum &>())
        .def(py::init<const char *>())
        .def(py::init<const ArrayXd &, const Vector2i &>())
        .def("save", &Bitmap::save)
        .def_property("m_data", &Bitmap::getData, &Bitmap::setData)
        .def_readwrite("m_res", &Bitmap::m_res);

    py::class_<BSDF, std::unique_ptr<BSDF, py::nodelete>>(m, "BSDF");
    py::class_<DiffuseBSDF, BSDF, std::unique_ptr<DiffuseBSDF, py::nodelete>>(m, "DiffuseBSDF")
        .def(py::init<>())
        .def(py::init<Spectrum>())
        .def_readwrite("reflectance", &DiffuseBSDF::reflectance);
    py::class_<RoughConductorBSDF, BSDF, std::unique_ptr<RoughConductorBSDF, py::nodelete>>(m, "RoughConductorBSDF")
        .def(py::init<float, Spectrum, Spectrum>());
    py::class_<RoughDielectricBSDF, BSDF, std::unique_ptr<RoughDielectricBSDF, py::nodelete>>(m, "RoughDielectricBSDF")
        .def(py::init<Float, Float, Float>());
    py::class_<NullBSDF, BSDF, std::unique_ptr<NullBSDF, py::nodelete>>(m, "NullBSDF")
        .def(py::init<>());

    py::class_<PhaseFunction, std::unique_ptr<PhaseFunction, py::nodelete>>(m, "Phase");
    py::class_<HGPhaseFunction, PhaseFunction, std::unique_ptr<HGPhaseFunction, py::nodelete>>(m, "HG")
        .def(py::init<float>());
    py::class_<IsotropicPhaseFunction, PhaseFunction, std::unique_ptr<IsotropicPhaseFunction, py::nodelete>>(m, "Isotropic")
        .def(py::init<>());

    py::class_<TetrahedronMesh, std::unique_ptr<TetrahedronMesh, py::nodelete>>(m, "TetrahedronMesh")
        .def(py::init<const std::vector<Vector> &,
                      const std::vector<Eigen::Matrix<int, 3, 1>> &,
                      const std::vector<std::pair<int, int>> &>())
        .def_readonly("TV", &TetrahedronMesh::m_TV)
        .def_readonly("TT", &TetrahedronMesh::m_TT)
        .def_readonly("volume", &TetrahedronMesh::m_vol)
        .def_readonly("vertices", &TetrahedronMesh::m_vertices)
        .def_readonly("ids", &TetrahedronMesh::m_ids);

    m.def("tetrahedralize", &tetrahedralize1);

    py::class_<Medium, std::unique_ptr<Medium, py::nodelete>>(m, "Medium");
    py::class_<Homogeneous, Medium, std::unique_ptr<Homogeneous, py::nodelete>>(m, "Homogeneous")
        .def(py::init<Float, const Spectrum &, int>())
        .def(py::init<Float, const VolumeGrid &, int>())
        .def_readwrite("sigma_t", &Homogeneous::sigma_t)
        .def_readwrite("albedo", &Homogeneous::albedo)
        .def_readwrite("sampling_weight", &Homogeneous::sampling_weight)
        .def_readwrite("phase_id", &Homogeneous::phase_id)
        .def_readwrite("tetmesh", &Homogeneous::m_tetmesh);
    py::class_<Heterogeneous, Medium, std::unique_ptr<Heterogeneous, py::nodelete>>(m, "Heterogeneous")
        .def(py::init<const Properties &>())
        .def(py::init<const VolumeGrid &, const VolumeGrid &, Float, int>())
        .def("configure", &Heterogeneous::configure)
        .def_readwrite("max_density", &Heterogeneous::m_max_density)
        .def_readwrite("inv_max_density", &Heterogeneous::m_inv_max_density)
        .def_readwrite("scale", &Heterogeneous::m_scale)
        .def_readwrite("sigmaT", &Heterogeneous::m_sigmaT)
        .def_readwrite("albedo", &Heterogeneous::m_albedo);

    py::class_<Emitter, std::unique_ptr<Emitter, py::nodelete>>(m, "Emitter");

    py::class_<AreaLight, Emitter, std::unique_ptr<AreaLight, py::nodelete>>(m, "AreaLight")
        .def(py::init<int, Spectrum>());

    py::class_<EnvironmentMap, Emitter, std::unique_ptr<EnvironmentMap, py::nodelete>>(m, "EnvironmentMap")
        .def(py::init<const Properties&>())
        .def(py::init<std::string, int>())
        .def_readwrite("m_data", &EnvironmentMap::m_data)
        .def_readwrite("m_cube_distrb", &EnvironmentMap::m_cube_distrb)
        .def_readwrite("m_scale", &EnvironmentMap::m_scale);

    py::class_<Scene, Object>(m, "Scene")
        .def(py::init<>())
        .def(py::init<const std::string &, bool, const Properties&>(), "filename"_a, "auto_configure"_a = true, "props"_a = Properties())
        .def(py::init<const Camera &,
                      const std::vector<Shape *> &,
                      const std::vector<BSDF *> &,
                      const std::vector<Emitter *> &,
                      const std::vector<PhaseFunction *> &,
                      const std::vector<Medium *> &>())
        .def(py::init<const Camera &,
                      const std::vector<Shape *> &,
                      const std::vector<BSDF *> &,
                      const std::vector<Emitter *> &,
                      const std::vector<PhaseFunction *> &,
                      const std::vector<Medium *> &,
                      bool>())
        .def("clone", &Scene::clone)
        .def("setZero", &Scene::setZero)
        .def("configure", &Scene::configure, "props"_a = Properties())
        .def("load_file", &Scene::load_file, "file_name"_a, "auto_configure"_a = true, "props"_a = Properties())
        .def("load_envmap", &Scene::load_envmap, "file_name"_a)
        .def_readonly("param_map", &Scene::m_param_map, "Parameter map")
        .def_readwrite("shapes", &Scene::shape_list)
        .def_readwrite("bsdfs", &Scene::bsdf_list)
        .def_readwrite("emitters", &Scene::emitter_list)
        .def_readwrite("mediums", &Scene::medium_list)
        .def_readwrite("phases", &Scene::phase_list)
        .def_readwrite("cameras", &Scene::cameras)
        .def_readwrite("camera", &Scene::camera)
        .def_readwrite("state", &Scene::state)
        .def_readwrite("medium_shape_distrb", &Scene::medium_shape_distrb);

    py::class_<SceneAD>(m, "SceneAD")
        .def(py::init<const Scene &>())
        .def_readonly("val", &SceneAD::val)
        .def_readonly("der", &SceneAD::der)
        .def_readonly("gm", &SceneAD::gm)
        .def("getDer", &SceneAD::getDer)
        .def("zeroGrad", &SceneAD::zeroGrad);

    py::class_<GradientManager<Scene>>(m, "GradientManager")
        .def_readonly("d_scenes", &GradientManager<Scene>::d_scenes)
        .def("merge", &GradientManager<Scene>::merge);

    py::class_<RenderOptions>(m, "RenderOptions")
        .def(py::init<>())
        .def(py::init<uint64_t, int, int, int, int, bool>())
        .def(py::init<uint64_t, int, int, int, int, bool, int>())
        .def(py::init<uint64_t, int, int, int, int, bool, int, float>())
        .def_readwrite("seed", &RenderOptions::seed)
        .def_readwrite("spp", &RenderOptions::num_samples)
        .def_readwrite("sppe", &RenderOptions::num_samples_primary_edge)
        .def_readwrite("sppse", &RenderOptions::num_samples_secondary_edge)
        .def_readwrite("max_bounces", &RenderOptions::max_bounces)
        .def_readwrite("quiet", &RenderOptions::quiet)
        .def_readwrite("mode", &RenderOptions::mode)
        .def_readwrite("ddistCoeff", &RenderOptions::ddistCoeff)
        .def_readwrite("sppse0", &RenderOptions::num_samples_secondary_edge_direct)
        .def_readwrite("sppse1", &RenderOptions::num_samples_secondary_edge_indirect)
        .def_readwrite("sppe0", &RenderOptions::sppe0)
        .def_readwrite("grad_threshold", &RenderOptions::grad_threshold);

    py::class_<MALAOptions>(m, "MALAOptions")
        .def(py::init<>())
        .def(py::init<int, int, Float, Float, int, int, int, int, bool, bool, bool, Vector>())
        .def_readwrite("num_chains", &MALAOptions::num_chains)
        .def_readwrite("num_samples", &MALAOptions::num_samples)
        .def_readwrite("step_length", &MALAOptions::step_length)
        .def_readwrite("p_global", &MALAOptions::p_global);

    py::class_<Integrator>(m, "Integrator")
        .def("configure", &Integrator::configure)
        .def_readwrite("enable_antithetic", &Integrator::enable_antithetic)
        .def_readwrite("two_point_antithetic", &Integrator::two_point_antithetic)
        .def("renderC", &Integrator::renderC)
        .def("renderD", &Integrator::renderD);

    py::class_<MISIntegrator> misIntegrator(m, "MISIntegrator");

    misIntegrator
        .def(py::init<const MISIntegrator::ESamplingMode &>())
        .def(py::init<const Properties &>());

    py::enum_<MISIntegrator::ESamplingMode>(misIntegrator, "ESamplingMode", py::arithmetic())
        .value("ESolidAngle", MISIntegrator::ESolidAngle)
        .value("EArea", MISIntegrator::EArea)
        .value("EMISBalance", MISIntegrator::EMISBalance)
        .value("EMISPower", MISIntegrator::EMISPower)
        .value("EMISFirst", MISIntegrator::EMISFirst)
        .value("EMISPath", MISIntegrator::EMISPath)
        .value("ESkipSensor", MISIntegrator::ESkipSensor)
        .value("EDebugMISBalance", MISIntegrator::EDebugMISBalance)
        .value("EDebugMISPower", MISIntegrator::EDebugMISPower)
        .export_values();

    

    py::class_<IntegratorBoundary>(m, "IntegratorBoundary")
        .def("configure", &IntegratorBoundary::configure);

    py::class_<UnidirectionalPathTracer, Integrator>(m, "UnidirectionalPathTracer")
        .def("renderC", &UnidirectionalPathTracer::renderC)
        .def("renderD", &UnidirectionalPathTracer::renderD)
        .def("pixelColor", &UnidirectionalPathTracer::pixelColor)
        .def("pixelColorAD", &UnidirectionalPathTracer::pixelColorAD);

//     py::class_<Direct, Integrator>(m, "Direct")
//         .def(py::init<>())
//         .def("radiance", &Direct::radiance)
//         .def("Li", &Direct::Li)
//         .def("d_Li", &Direct::d_Li)
//         .def("render", &Direct::render)
//         .def("renderC", &Direct::renderC)
//         .def("renderD", &Direct::renderD);

//     py::class_<Direct2, Integrator>(m, "Direct2")
//         .def(py::init<>())
//         .def("renderC", &Direct2::renderC)
//         .def("renderD", &Direct2::renderD);

//     py::class_<Path, Integrator>(m, "Path")
//         .def(py::init<>())
//         .def("render", &Path::render)
//         .def("renderC", &Path::renderC)
//         .def("renderD", &Path::renderD);

    py::class_<Path2, UnidirectionalPathTracer>(m, "Path2")
        .def(py::init<>())
        .def(py::init<const Properties &>())
        .def("Li", &Path2::Li)
        .def("LiAD", &Path2::LiAD)
        .def("forwardRenderD", &Path2::forwardRenderD);
        
    py::class_<Mask, UnidirectionalPathTracer>(m, "Mask")
        .def(py::init<>())
        .def(py::init<const Properties &>())
        .def("Li", &Mask::Li)
        .def("LiAD", &Mask::LiAD);

//     py::class_<Volpath, UnidirectionalPathTracer>(m, "Volpath")
//         .def(py::init<>())
//         .def(py::init<const Properties &>())
//         .def("Li", &Volpath::Li)
//         .def("Lins", &Volpath::Lins)
//         .def("LiAD", &Volpath::LiAD)
//         .def("LiFwd", &Volpath::LiFwd);

//     py::class_<VolpathInterior, UnidirectionalPathTracer>(m, "VolpathInterior")
//         .def(py::init<>())
//         .def(py::init<const Properties &>())
//         .def("configure", &VolpathInterior::configure)
//         .def("Li", &VolpathInterior::Li)
//         .def("LiAD", &VolpathInterior::LiAD);

//     py::class_<BoundaryUnidirectional, UnidirectionalPathTracer>(m, "BoundaryUnidirectional")
//         .def(py::init<>())
//         .def(py::init<const Properties &>())
//         .def("Li", &BoundaryUnidirectional::Li)
//         .def("LiAD", &BoundaryUnidirectional::LiAD);

//     py::class_<Volpath2, Integrator>(m, "Volpath2")
//         .def(py::init<>());

//     py::class_<VolpathMerged, Integrator>(m, "VolpathMerged")
//         .def(py::init<>())
//         .def(py::init<const Properties &>());

//     py::class_<BoundaryBidirectional, Integrator>(m, "BoundaryBidirectional")
//         .def(py::init<const Scene&>())
//         .def(py::init<const Properties &>())
//         .def("configure", &BoundaryBidirectional::configure)
//         .def_readwrite("adaptive", &BoundaryBidirectional::m_adaptive)
//         .def_readwrite("shapeDistribution", &BoundaryBidirectional::m_shapeDistribution)
//         .def_readwrite("m_sampling_mode", &BoundaryBidirectional::m_sampling_mode);

//     py::class_<ParticleTracer2, Integrator>(m, "PTracer")
//         .def(py::init<>())
//         .def_readwrite("is_equal_trans", &ParticleTracer2::is_equal_trans)
//         .def("renderC", &ParticleTracer2::renderC)
//         .def("renderD", &ParticleTracer2::renderD);
    
//     py::class_<BDPT2>(m, "Bdpt")
//         .def(py::init<>())
//         .def(py::init<bool>())
//         .def_readwrite("mApplyAntithetic", &BDPT2::mApplyAntithetic)
//         .def("renderC", &BDPT2::renderC)
//         .def("renderD", &BDPT2::renderD);

    py::class_<PrimaryEdgeIntegrator, IntegratorBoundary>(m, "PrimaryEdgeIntegrator")
        .def(py::init<const Scene &>())
        .def(py::init<const Properties &>())
        // .def("render", &DirectADps::render)
        .def("renderD", &PrimaryEdgeIntegrator::renderD)
        .def("forwardRenderD", &PrimaryEdgeIntegrator::forwardRenderD)
        .def("configure", &PrimaryEdgeIntegrator::configure);

    py::class_<DirectEdgeIntegrator, IntegratorBoundary>(m, "DirectEdgeIntegrator")
        .def(py::init<const Scene &>())
        .def(py::init<const Properties &>())
        // .def("render", &DirectADps::render)
        .def("preprocess_aq", &DirectEdgeIntegrator::preprocess_aq)
        .def("preprocess_grid", &DirectEdgeIntegrator::preprocess_grid)
        .def("renderD", &DirectEdgeIntegrator::renderD)
        .def("forwardRenderD", &DirectEdgeIntegrator::forwardRenderD)
        .def_readwrite("edge_indices", &DirectEdgeIntegrator::edge_indices)
        .def_readwrite("aq_distrb", &DirectEdgeIntegrator::aq_distrb);

    // py::class_<MetropolisDirectEdgeIntegrator, DirectEdgeIntegrator>(m, "DirectEdgeMLT")
    //     .def(py::init<const Scene &>())
    //     .def(py::init<const Properties &>())
    //     // .def("render", &DirectADps::render)
    //     .def("renderD", &MetropolisDirectEdgeIntegrator::renderD)
    //     .def("load_MALA_config", &MetropolisDirectEdgeIntegrator::load_MALA_config)
    //     .def("get_sample_vol", &MetropolisDirectEdgeIntegrator::get_sample_vol)
    //     // .def("preprocess_aq", &MetropolisDirectEdgeIntegrator::preprocess_aq)
    //     .def("preprocess_grid", &MetropolisDirectEdgeIntegrator::preprocess_grid)
    //     .def("diff_bsdf_test", &MetropolisDirectEdgeIntegrator::diff_bsdf_test);

    py::class_<IndirectEdgeIntegrator, IntegratorBoundary>(m, "IndirectEdgeIntegrator")
        .def(py::init<const Scene &>())
        .def(py::init<const Properties &>())
        // .def("render", &DirectADps::render)
        .def("preprocess_aq", &IndirectEdgeIntegrator::preprocess_aq)
        .def("preprocess_grid", &IndirectEdgeIntegrator::preprocess_grid)
        .def("renderD", &IndirectEdgeIntegrator::renderD)
        .def("forwardRenderD", &IndirectEdgeIntegrator::forwardRenderD)
        .def_readwrite("edge_indices", &IndirectEdgeIntegrator::edge_indices)
        .def_readwrite("aq_distrb", &IndirectEdgeIntegrator::aq_distrb);

    py::class_<MetropolisIndirectEdgeIntegrator, IndirectEdgeIntegrator>(m, "IndirectEdgeMLT")
        .def(py::init<const Scene &>())
        .def(py::init<const Properties &>())
        // .def("render", &DirectADps::render)
        .def("renderD", &MetropolisIndirectEdgeIntegrator::renderD)
        .def("load_MALA_config", &MetropolisIndirectEdgeIntegrator::load_MALA_config)
        .def("preprocess_aq", &MetropolisIndirectEdgeIntegrator::preprocess_aq)
        .def("preprocess_grid", &MetropolisIndirectEdgeIntegrator::preprocess_grid)
        .def("solve_Grid", &MetropolisIndirectEdgeIntegrator::solve_Grid)
        .def("solve_MALA", &MetropolisIndirectEdgeIntegrator::solve_MALA)
        .def("solve_Grid_rough", &MetropolisIndirectEdgeIntegrator::solve_Grid_rough)
        .def("solve_MALA_rough", &MetropolisIndirectEdgeIntegrator::solve_MALA_rough)
        .def("get_edge_ray", &MetropolisIndirectEdgeIntegrator::get_edge_ray)
        .def("perturbe_sample", &MetropolisIndirectEdgeIntegrator::perturbe_sample);

    py::class_<BoundaryIntegrator, IntegratorBoundary>(m, "BoundaryIntegrator")
        .def(py::init<const Scene &>())
        .def("configure_primary", &BoundaryIntegrator::configure_primary)
        .def("configure_mala", &BoundaryIntegrator::configure_mala)
        .def("recompute_direct_edge", &BoundaryIntegrator::recompute_direct_edge)
        .def("recompute_indirect_edge", &BoundaryIntegrator::recompute_indirect_edge)
        .def("preprocess_grid_direct", &BoundaryIntegrator::preprocess_grid_direct)
        .def("preprocess_aq_direct", &BoundaryIntegrator::preprocess_aq_direct)
        .def("preprocess_grid_indirect", &BoundaryIntegrator::preprocess_grid_indirect)
        .def("preprocess_aq_indirect", &BoundaryIntegrator::preprocess_aq_indirect)
        // .def("render", &DirectADps::render)
        .def("renderD", &BoundaryIntegrator::renderD)
        .def("forwardRenderD", &BoundaryIntegrator::forwardRenderD)
        .def("configure_primary", &BoundaryIntegrator::configure_primary);

//     py::class_<VolPrimaryEdgeIntegrator, IntegratorBoundary>(m, "VolPrimaryEdgeIntegrator")
//         .def(py::init<const Scene &>())
//         .def(py::init<const Properties &>())
//         .def("configure", &VolPrimaryEdgeIntegrator::configure)
//         // .def("render", &DirectADps::render)
//         .def("renderD", &VolPrimaryEdgeIntegrator::renderD);

//     py::class_<VolDirectEdgeIntegrator, IntegratorBoundary>(m, "VolDirectEdgeIntegrator")
//         .def(py::init<const Scene &>())
//         .def(py::init<const Properties &>())
//         .def("renderD", &VolDirectEdgeIntegrator::renderD)
//         .def_readwrite("edge_indices", &VolDirectEdgeIntegrator::edge_indices);

//     py::class_<VolIndirectEdgeIntegrator, IntegratorBoundary>(m, "VolIndirectEdgeIntegrator")
//         .def(py::init<const Scene &>())
//         .def(py::init<const Properties &>())
//         .def("renderD", &VolIndirectEdgeIntegrator::renderD)
//         .def_readwrite("edge_indices", &VolIndirectEdgeIntegrator::edge_indices);

//     py::class_<VolBoundaryIntegrator, IntegratorBoundary>(m, "VolBoundaryIntegrator")
//         .def(py::init<const Scene &>())
//         .def(py::init<const Properties &>())
//         .def("configure_primary", &VolBoundaryIntegrator::configure_primary)
//         .def("renderD", &VolBoundaryIntegrator::renderD);

    py::class_<BoundarySegmentInfo>(m, "BoundarySegmentInfo")
        .def(py::init<>())
        .def_readwrite("xS_0", &BoundarySegmentInfo::xS_0)
        .def_readwrite("xS_1", &BoundarySegmentInfo::xS_1)
        .def_readwrite("xS_2", &BoundarySegmentInfo::xS_2)
        .def_readwrite("xB_0", &BoundarySegmentInfo::xB_0)
        .def_readwrite("xB_1", &BoundarySegmentInfo::xB_1)
        .def_readwrite("xD_0", &BoundarySegmentInfo::xD_0)
        .def_readwrite("xD_1", &BoundarySegmentInfo::xD_1)
        .def_readwrite("xD_2", &BoundarySegmentInfo::xD_2);

    py::class_<PixelBoundarySamplingRecord>(m, "PixelBoundarySamplingRecord")
        .def(py::init<>())
        .def_readwrite("onSurface_S", &PixelBoundarySamplingRecord::onSurface_S)
        .def_readwrite("dir_local", &PixelBoundarySamplingRecord::dir_local)
        .def_readwrite("dir_visible", &PixelBoundarySamplingRecord::dir_visible)
        .def_readwrite("dir_edge", &PixelBoundarySamplingRecord::dir_edge)
        .def_readwrite("shape_id", &PixelBoundarySamplingRecord::shape_id)
        .def_readwrite("triangle_id", &PixelBoundarySamplingRecord::triangle_id)
        .def_readwrite("med_id_S", &PixelBoundarySamplingRecord::med_id_S)
        .def_readwrite("tet_id_S", &PixelBoundarySamplingRecord::tet_id_S)
        .def_readwrite("barycentric4_S", &PixelBoundarySamplingRecord::barycentric4_S);

    py::class_<PixelBoundarySegmentInfo>(m, "PixelBoundarySegmentInfo")
        .def(py::init<>())
        .def_readwrite("xD", &PixelBoundarySegmentInfo::xD)
        .def_readwrite("xB", &PixelBoundarySegmentInfo::xB)
        .def_readwrite("xS_0", &PixelBoundarySegmentInfo::xS_0)
        .def_readwrite("xS_1", &PixelBoundarySegmentInfo::xS_1)
        .def_readwrite("xS_2", &PixelBoundarySegmentInfo::xS_2)
        .def("getVelocities", &PixelBoundarySegmentInfo::getVelocities)
        .def("maxCoeff", static_cast<Float (PixelBoundarySegmentInfo::*)() const>(&PixelBoundarySegmentInfo::maxCoeff))
        .def("setZero", &PixelBoundarySegmentInfo::setZero);

    py::class_<PixelBoundaryIntegrator, Integrator, IntegratorBoundary>(m, "PixelBoundaryIntegrator")
        .def(py::init<>())
        .def(py::init<const Properties &>())
        .def(py::init<const Scene &>())
        .def("renderC", &PixelBoundaryIntegrator::renderC)
        .def("renderD", &PixelBoundaryIntegrator::renderD)
        .def("guide", &PixelBoundaryIntegrator::guide)
        .def("getSampleMap", &PixelBoundaryIntegrator::getSampleMap)
        .def("getVarMap", &PixelBoundaryIntegrator::getVarMap)
        .def("guidingDensity", &PixelBoundaryIntegrator::guidingDensity)
        .def("normalVelocity", &PixelBoundaryIntegrator::normalVelocity)
        .def_readwrite("guideMap", &PixelBoundaryIntegrator::guideMap)
        .def("Li", &PixelBoundaryIntegrator::Li)
        .def("velocities", &PixelBoundaryIntegrator::velocities);

    // algorithm 1 for volume
//     py::module_ m_algorithm1_vol = m.def_submodule("algorithm1_vol", "algorithm 1 for volume");
//     m_algorithm1_vol.def("eval", &algorithm1_vol::eval);
//     m_algorithm1_vol.def("d_eval", &algorithm1_vol::d_eval);
//     m_algorithm1_vol.def("evalFwd", &algorithm1_vol::evalFwd);
//     m_algorithm1_vol.def("d_evalPath", &algorithm1_vol::d_evalPath);
//     m_algorithm1_vol.def("evalPathFwd", &algorithm1_vol::evalPathFwd);
//     m_algorithm1_vol.def("d_getPoint", &algorithm1_vol::d_getPoint);
//     m_algorithm1_vol.def("d_getPath", &algorithm1_vol::d_getPath);
//     m_algorithm1_vol.def("getPathFwd", &algorithm1_vol::getPathFwd);
//     m_algorithm1_vol.def("d_evalVertex", &algorithm1_vol::d_evalVertex);
//     m_algorithm1_vol.def("evalVertexFwd", &algorithm1_vol::evalVertexFwd);
//     m_algorithm1_vol.def("baseline", &algorithm1_vol::baseline);
//     m_algorithm1_vol.def("baselineFwd", &algorithm1_vol::baselineFwd);

//     py::module_ m_volpath_meta = m.def_submodule("volpath_meta", "volume path meta");
//     m_volpath_meta.def("__Li", &volpath_meta::__Li);

//     py::module_ m_test = m.def_submodule("test", "test");
//     m_test.def("lookupFloat", &lookupFloat);
//     m_test.def("d_lookupFloat", &d_lookupFloat);
    // ============================ Transform ============================
    m.def("look_at", psdr::transform::look_at, "origin"_a, "target"_a, "up"_a);

//     m.attr("angleEps") = AngleEpsilon;
//     m.attr("edgeEps") = EdgeEpsilon;
//     m.attr("version") = "0.0.6";
// #ifdef FORWARD
//     m.attr("forward_enabled") = true;
// #else
//     m.attr("forward_enabled") = false;
// #endif

    // ====================== verbose begin ====================
    m.attr("verbose") = &verbose;
    m.def("get_verbose", &get_verbose);
    m.def("set_verbose", &set_verbose);
    // ====================== verbose end ======================
    m.attr("forward") = &forward;
    m.def("get_forward", &get_forward);
    m.def("set_forward", &set_forward);
    // ====================== test begin ======================
    // m.def("__test_geometry", &__test_geometry);
    m.def("load_from_string", &psdr::SceneLoader::load_from_string1);

    m.def("omp_get_thread_num", &omp_get_thread_num);
    m.def("omp_get_num_procs", &omp_get_num_procs);
}

// int main(int argc, const char** argv) {
//     Logger::static_init();
//     set_verbose(true);
//     assert(argc == 2);
//     std::string filename = argv[1];
//     PSDR_INFO("Loading file : {}", filename);
//     Scene scene(filename);
//     Volpath volpath;
//     RenderOptions options(0, 1000, 10, 0, 0, 0, 0);
//     auto image = volpath.renderC(scene, options);
//     Bitmap bitmap(image, Vector2i(scene.camera.width, scene.camera.height));
//     bitmap.save("test.exr");
//     return 0;
// }

#include <core/logger.h>
#include <integrator/direct.h>
#include <render/common.h>
#include <integrator/path2.h>
#include <render/scene.h>
#include <integrator/test.h>
#include <bsdf/diffuse.h>
#include <bsdf/roughconductor.h>
#include <bsdf/roughdielectric.h>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>

/// Command-line driver: loads a scene file, renders it with the Path2
/// integrator via renderC(), and saves the image with Bitmap::save().
///
/// Usage: <program> <scene-file> [output-path]
///   scene-file   path handed to the Scene constructor.
///   output-path  optional destination for the image; defaults to "test.exr".
///
/// @return 0 on success; 1 on bad usage or if loading/rendering/saving throws
///         a std::exception.
int main(int argc, const char *argv[]) {
    Logger::static_init();
    set_verbose(true);

    // Validate arguments explicitly: assert() disappears under NDEBUG, and
    // reading argv[1] when argc < 2 would be undefined behavior.
    if (argc < 2 || argc > 3) {
        std::cerr << "Usage: " << (argc > 0 ? argv[0] : "render")
                  << " <scene-file> [output-path]\n";
        return 1;
    }

    const std::string filename = argv[1];
    const std::string output = (argc == 3) ? argv[2] : "test.exr";

    try {
        PSDR_INFO("Loading file : {}", filename);
        Scene scene(filename);
        Path2 path;
        // NOTE(review): positional RenderOptions arguments are kept exactly as
        // in the original driver; their meaning is not visible in this file —
        // confirm against the RenderOptions declaration before changing them.
        RenderOptions options(0, 1000, 10, 0, 0, 0, 0);
        auto image = path.renderC(scene, options);
        Bitmap bitmap(image, Vector2i(scene.camera.width, scene.camera.height));
        // c_str() keeps this compatible whether save() takes a C string or a
        // std::string-like type.
        bitmap.save(output.c_str());
    } catch (const std::exception &e) {
        // Report instead of letting the exception escape main (std::terminate).
        std::cerr << "Error: " << e.what() << '\n';
        return 1;
    }
    return 0;
}