#include <catch2/catch.hpp>
#include <cstdio>
#include <iostream>
#include <core/fwd.h>
#include <render/camera.h>
#include <core/sampler.h>
#include <render/common.h>
#include <render/intersection.h>
#include "rfilter.h"

using namespace std;

// The anonymous namespace that would wrap the helpers below is disabled:
// namespace
// {
// Minimal scene used by this test: one camera and a single triangle.
// Kept as a plain aggregate so it can be brace-initialised in member order
// (camera, v0, v1, v2).
struct SceneData
{
    Camera camera; // camera whose ray sampling and filter are queried
    Vector v0;     // first triangle vertex
    Vector v1;     // second triangle vertex
    Vector v2;     // third triangle vertex
};

// Unit normal of the triangle (v0, v1, v2): the normalized cross product of
// the edges v0->v1 and v0->v2.
Vector faceNormal(Vector &v0, Vector &v1, Vector &v2)
{
    Vector edge01 = v1 - v0;
    Vector edge02 = v2 - v0;
    Vector unnormalized = edge01.cross(edge02);
    return unnormalized.normalized();
}

// Contribution of a single primary ray to the pixel `pixel_idx`.
//
// Intersects `ray` with the scene triangle, rebuilds the hit point from the
// barycentric coordinates (u, v) after passing them through detach(), so the
// hit point is expressed directly in terms of the vertices v0/v1/v2. It then
// fills an Intersection whose geometric and shading frames both use the face
// normal and returns the camera's filter value for that pixel.
//
// Side effect: prints the ray origin, ray direction and hit point to stdout.
//
// NOTE(review): the result of rayIntersectTriangle is never checked for a
// miss — presumably the test geometry guarantees a hit; confirm.
Float eval(Ray &ray, SceneData &scene, Array2i &pixel_idx)
{
    // Debug trace of the incoming ray.
    printf("ray.org : %f, %f, %f \n", ray.org[0], ray.org[1], ray.org[2]);
    printf("ray.dir : %f, %f, %f \n", ray.dir[0], ray.dir[1], ray.dir[2]);

    Vector &v0 = scene.v0;
    Vector &v1 = scene.v1;
    Vector &v2 = scene.v2;

    // Only the first two components are used; the third component of uvt
    // was previously read into an unused local and has been dropped.
    Array uvt = rayIntersectTriangle(v0, v1, v2, ray);
    Float u = detach(uvt.x());
    Float v = detach(uvt.y());

    // Barycentric reconstruction of the hit point from the vertices.
    Vector p = (1 - u - v) * v0 + u * v1 + v * v2;
    printf("p : %f, %f, %f \n", p[0], p[1], p[2]);

    // Unit direction from the camera position to the hit point.
    Vector dir = p - scene.camera.cpos;
    Float dist = dir.norm();
    dir /= dist;

    Vector n = faceNormal(v0, v1, v2);
    Intersection its;
    its.geoFrame = Frame(n);
    its.shFrame = Frame(n);
    its.p = p;
    // Incident direction points back toward the camera, in local coordinates.
    its.wi = its.toLocal(-dir);
    return scene.camera.evalFilter(pixel_idx.x(), pixel_idx.y(), its);
}

// Value of pixel `pixel_idx` for one 2D sample drawn from `sampler`.
//
// The camera turns the sample into two rays (primal and dual) via
// samplePrimaryRayFromFilter; the result is the sum of eval() over both,
// evaluated primal first. Prints the drawn sample to stdout.
Float pixelColor(SceneData &scene, Array2i &pixel_idx, RndSampler &sampler)
{
    Ray primal_ray;
    Ray dual_ray;

    Array2 sample = sampler.next2D();
    printf("rnd : %f, %f \n", sample[0], sample[1]);

    scene.camera.samplePrimaryRayFromFilter(pixel_idx.x(), pixel_idx.y(),
                                            sample, primal_ray, dual_ray);

    // Accumulate sequentially so the two eval() calls (and their debug
    // output) happen in a fixed order.
    Float total = 0.;
    total += eval(primal_ray, scene, pixel_idx);
    total += eval(dual_ray, scene, pixel_idx);
    return total;
}

// NOTE: this wrapper cannot be placed inside the (currently disabled)
// anonymous namespace.
//
// Asks Enzyme to differentiate pixelColor. `scene` is passed as enzyme_dup
// together with `d_scene`, while `pixel_idx` and `sampler` are passed as
// enzyme_const. Per Enzyme's enzyme_dup convention, `d_scene` is the shadow
// that receives the derivatives with respect to `scene`.
//
// The call is compiled only when both ENZYME and ENZYME_TEST are defined;
// otherwise the function does nothing. The return value is always 0 and
// carries no information.
Float d_pixelColor(SceneData &scene, SceneData &d_scene,
                   Array2i &pixel_idx, RndSampler &sampler)
{
#if defined(ENZYME) && defined(ENZYME_TEST)
    __enzyme_autodiff((void *)pixelColor,
                      enzyme_dup, &scene, &d_scene,
                      enzyme_const, &pixel_idx,
                      enzyme_const, &sampler);
#endif
    return 0.;
}

// Evaluates a TentFilter constructed with 0.5 at `x` and writes the result
// to `res`. Input and output are both passed by reference, which is the
// form d_evalFilter hands to Enzyme.
void evalFilter(Float &x, Float &res)
{
    TentFilter tent(0.5);
    const Float weight = tent.eval(x);
    res = weight;
}

// Asks Enzyme to differentiate evalFilter. Both (x, d_x) and (res, d_res)
// are passed as enzyme_dup primal/shadow pairs: the caller seeds `d_res`
// and, per Enzyme's convention, reads the derivative back from `d_x`.
//
// The Enzyme call is guarded the same way as in d_pixelColor, so builds
// without ENZYME / ENZYME_TEST do not reference __enzyme_autodiff; in that
// configuration the function leaves all arguments untouched.
void d_evalFilter(Float &x, Float &d_x, Float &res, Float &d_res)
{
#if defined(ENZYME) && defined(ENZYME_TEST)
    // Called as a plain statement: the function returns void, so the
    // intrinsic's result (if any) must not be forwarded with `return`.
    __enzyme_autodiff((void *)evalFilter,
                      enzyme_dup, &x, &d_x,
                      enzyme_dup, &res, &d_res);
#else
    // Enzyme is not enabled: nothing to differentiate.
    (void)x;
    (void)d_x;
    (void)res;
    (void)d_res;
#endif
}

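// Disabled scalar check of evalFilter / d_evalFilter: prints the filter value
// at x = 0.7 and the resulting d_x. It reuses the name "filter" of the active
// test case below.
// TEST_CASE("filter")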
// {
//     Float v = 0;
//     Float d_v = 1.;
//     Float x = 0.7;
//     Float d_x = 0.;
//     evalFilter(x, v);
//     std::cout << v << std::endl;

//     d_evalFilter(x, d_x, v, d_v);
//     std::cout << d_x << std::endl;
// }

// Evaluates one pixel of a single-triangle scene with pixelColor, then
// rewinds the sampler and runs d_pixelColor on the same sample. Prints the
// pixel value and the shadow of vertex v0; nothing is asserted.
TEST_CASE("filter", "[filter]")
{
    Matrix4x4 camToWorld;
    camToWorld << 1.0000, 0.0000, 0.0000, 0.0000,
        -0.0000, 0.4472, -0.8944, 20.0000,
        0.0000, -0.8944, -0.4472, 10.0000,
        0.0000, 0.0000, 0.0000, 1.0000;

    Matrix3x3 camToNdc;
    camToNdc << 0.9743, 0.0000, 0.0000,
        0.0000, 0.9743, 0.0000,
        0.0000, 0.0000, 1.0000;

    // Triangle vertices.
    Vector vertexA(0, 0, 0);
    Vector vertexB(-1, 0, 1);
    Vector vertexC(0, 0, 1);

    Camera cam(320, 180, camToWorld, camToNdc, 0.01, -1);

    // Primal scene and its shadow; the shadow's vertices start at zero.
    // NOTE(review): the shadow's camera is a copy of the primal camera, not
    // a zeroed one — only the v0 shadow is inspected below; confirm intent.
    const Vector zero(0, 0, 0);
    SceneData primalScene{cam, vertexA, vertexB, vertexC};
    SceneData shadowScene{cam, zero, zero, zero};

    Array2i pixel(152, 104);
    RndSampler rng(12, ravel_multi_index(pixel, {320, 180}));

    // Save the sampler state so the derivative pass replays the same sample.
    rng.save();
    Float pixelValue = pixelColor(primalScene, pixel, rng);
    std::cout << pixelValue << std::endl;

    rng.restore();
    d_pixelColor(primalScene, shadowScene, pixel, rng);
    std::cout << shadowScene.v0 << std::endl;
}