#include <render/embree_scene.h>

// Creates the Embree device (default configuration) and the single scene that
// all meshes added through insert() are attached to. Both handles are released
// by the destructor.
EmbreeScene::EmbreeScene()
{
    embree_device = rtcNewDevice(nullptr);

    // Configure the scene fully before publishing it in the member:
    // robust traversal mode and the high-quality BVH build.
    auto scene = rtcNewScene(embree_device);
    rtcSetSceneFlags(scene, RTC_SCENE_FLAG_ROBUST);
    rtcSetSceneBuildQuality(scene, RTC_BUILD_QUALITY_HIGH);
    embree_scene = scene;
}

// Releases the Embree handles acquired in the constructor, in reverse order:
// the scene is released before the device it was created on.
// NOTE(review): these are raw handles released unconditionally; if
// EmbreeScene can be copied, two copies would release the same handles twice.
// Confirm the copy constructor/assignment are deleted in the header.
EmbreeScene::~EmbreeScene()
{
    rtcReleaseScene(embree_scene);
    rtcReleaseDevice(embree_device);
}

void EmbreeScene::insert(const std::vector<Vector3> &V,
                         const std::vector<Vector3i> &F,
                         bool commit_scene)
{
    auto mesh = rtcNewGeometry(embree_device, RTC_GEOMETRY_TYPE_TRIANGLE);

    auto vertices = (Vector4f *)rtcSetNewGeometryBuffer(
        mesh, RTC_BUFFER_TYPE_VERTEX, 0, RTC_FORMAT_FLOAT3, sizeof(Vector4f),
        V.size());
    for (auto i = 0; i < static_cast<int>(V.size()); i++)
    {
        auto vertex = V[i];
        vertices[i] = Vector4f(vertex(0), vertex(1), vertex(2), 0.f);
    }
    auto triangles = (Vector3i *)rtcSetNewGeometryBuffer(
        mesh, RTC_BUFFER_TYPE_INDEX, 0, RTC_FORMAT_UINT3, sizeof(Vector3i),
        F.size());
    for (auto i = 0; i < static_cast<int>(F.size()); i++)
        triangles[i] = F[i];

    rtcSetGeometryVertexAttributeCount(mesh, 1);
    rtcCommitGeometry(mesh);
    // NOTE: id start from 0
    rtcAttachGeometry(embree_scene, mesh);

    rtcReleaseGeometry(mesh);

    if (commit_scene)
        rtcCommitScene(embree_scene);
}

void EmbreeScene::commit()
{
    rtcCommitScene(embree_scene);
}

using EIntersection = EmbreeScene::Intersection;
bool EmbreeScene::rayIntersect(const Ray &ray, EIntersection &its)
{
    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = ray.org.x();
    rtc_ray_hit.ray.org_y = ray.org.y();
    rtc_ray_hit.ray.org_z = ray.org.z();
    rtc_ray_hit.ray.dir_x = ray.dir.x();
    rtc_ray_hit.ray.dir_y = ray.dir.y();
    rtc_ray_hit.ray.dir_z = ray.dir.z();
    rtc_ray_hit.ray.tnear = ray.tmin;
    rtc_ray_hit.ray.tfar = ray.tmax;
    rtc_ray_hit.ray.mask = (unsigned int)(-1);
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    if (rtc_ray_hit.hit.geomID == RTC_INVALID_GEOMETRY_ID)
    {
        its.t = std::numeric_limits<Float>::infinity();
        its.shape_id = -1;
        its.triangle_id = -1;
        return false;
    }
    its.t = rtc_ray_hit.ray.tfar;
    its.shape_id = static_cast<int>(rtc_ray_hit.hit.geomID);
    its.triangle_id = static_cast<int>(rtc_ray_hit.hit.primID);
    return true;
}

// Convenience overload that traces _ray over [start, +infinity), ignoring the
// tmin/tmax stored in _ray.
//
// onSurface  true when the ray origin lies on a surface; the search then
//            starts at ShadowEpsilon instead of 0 (presumably to avoid
//            re-hitting the surface the ray starts on).
// its        filled exactly as by the two-argument overload, on hit or miss.
//
// Returns true iff a triangle was hit. The result of the traversal is
// propagated rather than unconditionally reporting a hit.
bool EmbreeScene::rayIntersect(const Ray &_ray, bool onSurface, EIntersection &its)
{
    Ray ray(_ray);
    ray.tmin = onSurface ? ShadowEpsilon : 0.0f;
    ray.tmax = std::numeric_limits<Float>::infinity();
    return rayIntersect(ray, its);
}

// Register both rayIntersect overloads with INACTIVE_FN. The macro is defined
// elsewhere in the project; presumably it marks these functions as inactive
// (e.g. excluded from differentiation) — confirm against its definition.
// The static_casts select each overload by its exact member-function-pointer
// type, since &EmbreeScene::rayIntersect alone is ambiguous.
INACTIVE_FN(EmbreeScene_rayIntersect,
            static_cast<bool (EmbreeScene::*)(const Ray &, EIntersection &)>(&EmbreeScene::rayIntersect));
INACTIVE_FN(EmbreeScene_rayIntersect2,
            static_cast<bool (EmbreeScene::*)(const Ray &, bool, EIntersection &its)>(&EmbreeScene::rayIntersect));