#include <bsdf/diffuse.h>
// #include <bsdf/diffuse2.h>
#include <bsdf/null.h>
#include <integrator/test.h>
#include <render/intersection.h>

void eval(const BSDF *bsdf, const Intersection &its, const Vector &wo,
          EBSDFMode mode, BSDFEvalType &ret) {
    ret = bsdf->eval(its, wo, mode);
}

// Reverse-mode derivative of eval() computed by Enzyme.
//   bsdf / d_bsdf : primal BSDF and its shadow (enzyme_dup); in reverse mode
//                   the gradient w.r.t. the BSDF's parameters is accumulated
//                   into d_bsdf, so callers should zero it first.
//   its, wo, mode : marked enzyme_const - no derivatives are taken for them.
//   ret / d_ret   : primal output and its shadow (enzyme_dup); on entry d_ret
//                   holds the adjoint seed for ret (Enzyme may overwrite it).
// NOTE(review): its and wo are passed as objects (not pointers) while eval()
// takes them by const reference; whether that matches depends on how
// __enzyme_autodiff is declared in the project headers - confirm.
void d_rev_eval(const BSDF *bsdf, BSDF *d_bsdf, const Intersection &its, const Vector &wo,
                EBSDFMode mode, BSDFEvalType &ret, BSDFEvalType &d_ret) {
    __enzyme_autodiff((void *) eval, enzyme_dup, bsdf, d_bsdf, enzyme_const, its, enzyme_const, wo, enzyme_const, mode, enzyme_dup, &ret, &d_ret);
}

// Forward-mode (tangent) derivative of eval() computed by Enzyme.
//   bsdf / d_bsdf : primal BSDF and its shadow (enzyme_dup); d_bsdf holds the
//                   tangent direction for the BSDF's parameters.
//   its, wo, mode : marked enzyme_const - treated as having zero tangent.
//   ret / d_ret   : primal output and its shadow (enzyme_dup); d_ret receives
//                   the directional derivative of ret along d_bsdf.
// NOTE(review): as in d_rev_eval, its and wo are passed as objects while
// eval() takes them by const reference - confirm against the
// __enzyme_fwddiff declaration.
void d_fwd_eval(const BSDF *bsdf, BSDF *d_bsdf, const Intersection &its, const Vector &wo,
                EBSDFMode mode, BSDFEvalType &ret, BSDFEvalType &d_ret) {
    __enzyme_fwddiff((void *) eval, enzyme_dup, bsdf, d_bsdf, enzyme_const, its, enzyme_const, wo, enzyme_const, mode, enzyme_dup, &ret, &d_ret);
}

void test_diffuse() {
    BSDF        *b       = new DiffuseBSDF(Spectrum(0.5, 0.2, 0.3));
    BSDF        *d_rev_b = new DiffuseBSDF(Spectrum::Zero());
    BSDF        *d_fwd_b = new DiffuseBSDF(Spectrum(1, 0, 0));
    Spectrum     rev_ret, d_rev_ret(1, 0, 0);
    Spectrum     fwd_ret, d_fwd_ret(0, 0, 0);
    Intersection its;
    Vector       wo;
    EBSDFMode    mode = EBSDFMode::ERadiance;
    its.wi.z()        = 1;
    wo.z()            = 1;
    // ret = (0.5, 0.2, 0.3) / pi = (0.1591, 0.0636, 0.09549)
    // d_ret = (0.5, 0, 0) / pi
    d_rev_eval(b, d_rev_b, its, wo, mode, rev_ret, d_rev_ret);
    d_fwd_eval(b, d_fwd_b, its, wo, mode, fwd_ret, d_fwd_ret);
    printf("reverse mode:\n");
    printf("  ret   = (%f, %f, %f)\n", rev_ret.x(), rev_ret.y(), rev_ret.z());
    Spectrum d_rev_res = dynamic_cast<DiffuseBSDF *>(d_rev_b)->reflectance.m_data[0];
    printf("  d_ret = (%f, %f, %f)\n", d_rev_res.x(), d_rev_res.y(), d_rev_res.z());

    printf("forward mode:\n");
    printf("  ret   = (%f, %f, %f)\n", fwd_ret.x(), fwd_ret.y(), fwd_ret.z());
    printf("  d_ret = (%f, %f, %f)\n", d_fwd_ret.x(), d_fwd_ret.y(), d_fwd_ret.z());
}

// void test_diffuse2() {
//     BSDF        *b       = new DiffuseBSDF2(Spectrum(0.5, 0.2, 0.3));
//     BSDF        *d_rev_b = new DiffuseBSDF2(Spectrum::Zero());
//     BSDF        *d_fwd_b = new DiffuseBSDF2(Spectrum(1, 0, 0));
//     Spectrum     rev_ret, d_rev_ret(1, 0, 0);
//     Spectrum     fwd_ret, d_fwd_ret(0, 0, 0);
//     Intersection its;
//     Vector       wo;
//     EBSDFMode    mode = EBSDFMode::ERadiance;
//     its.wi.z()        = 1;
//     wo.z()            = 1;
//     // ret = (0.5, 0.2, 0.3) * 1.2 = (0.6, 0.24, 0.36)
//     // d_ret = (0.5, 0, 0) * 1.2 = (0.6, 0, 0)
//     d_rev_eval(b, d_rev_b, its, wo, mode, rev_ret, d_rev_ret);
//     d_fwd_eval(b, d_fwd_b, its, wo, mode, fwd_ret, d_fwd_ret);
//     printf("reverse mode:\n");
//     printf("  ret   = (%f, %f, %f)\n", rev_ret.x(), rev_ret.y(), rev_ret.z());
//     Spectrum d_rev_res = dynamic_cast<DiffuseBSDF2 *>(d_rev_b)->reflectance.m_data[0];
//     printf("  d_ret = (%f, %f, %f)\n", d_rev_res.x(), d_rev_res.y(), d_rev_res.z());

//     printf("forward mode:\n");
//     printf("  ret   = (%f, %f, %f)\n", fwd_ret.x(), fwd_ret.y(), fwd_ret.z());
//     printf("  d_ret = (%f, %f, %f)\n", d_fwd_ret.x(), d_fwd_ret.y(), d_fwd_ret.z());
// }

// Primal function differentiated by d_configure_shape: calls configureC() on
// `shape` (presumably the non-AD configuration path; compare configureD).
// NOTE(review): identifiers beginning with "__" are reserved for the
// implementation; consider renaming if nothing outside this file uses it.
void __configure_shape(Shape *shape) {
    Shape &target = *shape;
    target.configureC();
}

// Reverse-mode derivative of __configure_shape via Enzyme. `shape` is the
// primal and `d_shape` its shadow (enzyme_dup): on entry d_shape carries the
// adjoints of configureC's outputs, and Enzyme accumulates the corresponding
// adjoints of its inputs into it.
// NOTE(review): not called in the visible part of this file (do_test uses
// Shape::configureD instead) - confirm whether it is still needed.
void d_configure_shape(Shape *shape, Shape *d_shape) {
    __enzyme_autodiff((void *)__configure_shape, enzyme_dup, shape, d_shape);
}

// Print every element of a vertex container as "x, y, z", one per line.
template <typename VertexContainer>
static void print_vertex_rows(const VertexContainer &verts) {
    for (const auto &v : verts) {
        printf("%f, %f, %f\n", v.x(), v.y(), v.z());
    }
}

// Smoke test. It runs test_diffuse(), seeds the derivative shape `d_s`, and
// calls d_s->configureD(*s). It then prints, in order: primal vertices,
// primal vertices_raw, primal m_to_world, derivative vertices_raw, derivative
// vertices, derivative m_to_world, followed by "shape done".
void do_test(Shape *s, Shape *d_s) {
    test_diffuse();
//   test_diffuse2();

    // Shape shape = *s;
    // // shape.m_to_world = Matrix4x4::Identity();
    // shape.vertices_raw.push_back(Vector3(1, 0, 0));
    // shape.vertices_raw.push_back(Vector3(0, 1, 0));
    // shape.vertices_raw.push_back(Vector3(0, 0, 1));
    Shape &shape   = *s;
    Shape &d_shape = *d_s;

    // Seed the derivative shape: reset it, set every entry of `vertices` to
    // all-ones, zero the transform derivative, then run configureD against
    // the primal shape.
    d_shape.setZero();
    for (auto &v : d_shape.vertices) {
        v = Vector::Ones();
    }
    d_shape.m_to_world = Matrix4x4::Zero();
    d_shape.configureD(shape);

    // std::endl (with its flush) is kept because output interleaves
    // iostream and printf.
    std::cout << std::endl;
    print_vertex_rows(shape.vertices);
    print_vertex_rows(shape.vertices_raw);
    std::cout << shape.m_to_world << std::endl;
    print_vertex_rows(d_shape.vertices_raw);
    print_vertex_rows(d_shape.vertices);
    std::cout << d_shape.m_to_world << std::endl;
    printf("shape done\n");
}

#include <integrator/path2.h>
#include <integrator/d_scene.h>

// End-to-end render/gradient driver. It is currently a stub: the whole
// pipeline below is commented out, so calling run() has no effect.
// NOTE(review): the disabled code sketches several steps: a primal render
// with Path2::renderC, a derivative render via Path2::renderD on a SceneAD,
// scene_configure_d, and saving the image to "test.exr". Re-enable it or
// remove it once that path is ready.
void run(Scene *p_scene) {
    // Only the disabled code uses the scene; mark it intentionally unused.
    static_cast<void>(p_scene);
    // std::string name = "test";
    // Scene &scene = *p_scene;
    // Scene d_scene = scene;
    // SceneAD sceneAD(*p_scene);
    // // d_scene.setZero();
    // // scene.configure(scene.m_properties);
    // // d_scene.configureD(scene);
    // // d_scene.configureD(scene);
    // Path2 path;
    // RenderOptions options(0, 10, 10, 0, 0, 0, 0);
    // auto image = path.renderC(scene, options);
    // // SceneAD sceneAD(scene);
    // // ArrayXd d_image(3 * 256 * 256);
    // path.renderD(sceneAD, options, image);
    // scene_configure_d(sceneAD);
    // // sceneAD.der.configureD(sceneAD.val);
    // Bitmap bitmap(image, Vector2i(scene.camera.width, scene.camera.height));
    // bitmap.save((name + ".exr").c_str());
    // // print_test_res(sceneAD, name + ".txt");
    // printf("run done\n");
}