#include <render/scene.h>
#include <core/math_func.h>
#include <algorithm>
#include <assert.h>
#include <iostream>
#include "emitter/area.h"
#include <scene_loader.h>
#include <core/logger.h>
#include "tetra.hpp"
#include <medium/homogeneous.h>
#include <core/statistics.h>

// template <typename T>
// T __enzyme_virtualreverse(T);

__attribute__((noinline)) bool _isIntersect(
    const RTCScene &embree_scene, const Ray &ray, Float dist)
{
    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = ray.org.x();
    rtc_ray_hit.ray.org_y = ray.org.y();
    rtc_ray_hit.ray.org_z = ray.org.z();
    rtc_ray_hit.ray.dir_x = ray.dir.x();
    rtc_ray_hit.ray.dir_y = ray.dir.y();
    rtc_ray_hit.ray.dir_z = ray.dir.z();
    rtc_ray_hit.ray.tnear = ShadowEpsilon;
    rtc_ray_hit.ray.tfar = dist;
    rtc_ray_hit.ray.mask = (unsigned int)(-1);
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    return rtc_ray_hit.hit.geomID != RTC_INVALID_GEOMETRY_ID;
}

__attribute__((noinline)) bool _rayIntersect(
    const RTCScene &embree_scene, const Ray &ray, bool onSurface, Intersection &its)
{
    Float tmin = onSurface ? ShadowEpsilon : 0.0f;
    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = ray.org.x();
    rtc_ray_hit.ray.org_y = ray.org.y();
    rtc_ray_hit.ray.org_z = ray.org.z();
    rtc_ray_hit.ray.dir_x = ray.dir.x();
    rtc_ray_hit.ray.dir_y = ray.dir.y();
    rtc_ray_hit.ray.dir_z = ray.dir.z();
    rtc_ray_hit.ray.tnear = tmin;
    rtc_ray_hit.ray.tfar = std::numeric_limits<Float>::infinity();
    rtc_ray_hit.ray.mask = (unsigned int)(-1);
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    if ( rtc_ray_hit.hit.geomID == RTC_INVALID_GEOMETRY_ID ) {
        its.t = std::numeric_limits<Float>::infinity();
        its.ptr_shape = nullptr;
        return false;
    }
    else {
        // Fill in the corresponding pointers
        its.indices[0] = static_cast<int>(rtc_ray_hit.hit.geomID);
        its.shape_id = its.indices[0];
        // Ray-Shape intersection
        its.indices[1] = static_cast<int>(rtc_ray_hit.hit.primID);
        its.triangle_id = its.indices[1];
        return true;
    }
}


__attribute__((noinline)) bool _rayIntersect2(
    const RTCScene &embree_scene, const Ray &ray, Intersection &its)
{
    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = ray.org.x();
    rtc_ray_hit.ray.org_y = ray.org.y();
    rtc_ray_hit.ray.org_z = ray.org.z();
    rtc_ray_hit.ray.dir_x = ray.dir.x();
    rtc_ray_hit.ray.dir_y = ray.dir.y();
    rtc_ray_hit.ray.dir_z = ray.dir.z();
    rtc_ray_hit.ray.tnear = ray.tmin;
    rtc_ray_hit.ray.tfar = ray.tmax;
    rtc_ray_hit.ray.mask = (unsigned int)(-1);
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    if (rtc_ray_hit.hit.geomID == RTC_INVALID_GEOMETRY_ID) {
        its.t = std::numeric_limits<Float>::infinity();
        its.ptr_shape = nullptr;
        return false;
    }
    else {
        // Fill in the corresponding pointers
        its.indices[0] = static_cast<int>(rtc_ray_hit.hit.geomID);
        its.shape_id = its.indices[0];
        // Ray-Shape intersection
        its.indices[1] = static_cast<int>(rtc_ray_hit.hit.primID);
        its.triangle_id = its.indices[1];
        return true;
    }
}


/**
 * Embree intersection-context filter: accepts a candidate hit only if its
 * geometry id equals the ray's mask value. _rayIntersectShape stores the
 * target mesh id in the ray mask to use this. Expects a single ray (N == 1).
 */
void intersectionFilter(const RTCFilterFunctionNArguments* args){
    /* avoid crashing when debug visualizations are used */
    if (!args->context)
        return;

    assert(args->N == 1);
    const unsigned int wanted_geom = RTCRayN_mask(args->ray, 1, 0);
    const unsigned int hit_geom = RTCHitN_geomID(args->hit, 1, 0);

    /* reject the hit when it belongs to another geometry */
    if (hit_geom != wanted_geom)
        args->valid[0] = 0;
}

/**
 * Closest-hit query restricted to one mesh (Embree geometry id `mesh_id`).
 *
 * The target id is passed to intersectionFilter through the ray mask. The
 * filter rejects candidate hits on other geometries, so traversal continues
 * past them.
 *
 * NOTE: Embree only invokes a filter stored in the intersection context when
 * the scene was committed with RTC_SCENE_FLAG_CONTEXT_FILTER_FUNCTION. Without
 * that flag, this reports a hit only when the nearest surface along the ray
 * already belongs to `mesh_id` (see the geomID check below).
 * NOTE(review): if Embree is built with ray-mask support, the mask also takes
 * part in geometry culling ((ray.mask & geometry mask) == 0 culls, so
 * mesh_id == 0 would cull everything). Confirm the Embree build.
 *
 * The search interval is [ShadowEpsilon, +inf); ray.tmin/tmax are ignored.
 * On a miss (or a hit on another mesh): its.t = +inf, its.ptr_shape = nullptr,
 * returns false. On a hit, only the ids in `its` are written.
 */
__attribute__((noinline)) bool _rayIntersectShape(int mesh_id, const RTCScene &embree_scene, const Ray &ray, Intersection &its) {
    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    rtc_context.filter = intersectionFilter;
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = ray.org.x();
    rtc_ray_hit.ray.org_y = ray.org.y();
    rtc_ray_hit.ray.org_z = ray.org.z();
    rtc_ray_hit.ray.dir_x = ray.dir.x();
    rtc_ray_hit.ray.dir_y = ray.dir.y();
    rtc_ray_hit.ray.dir_z = ray.dir.z();
    rtc_ray_hit.ray.tnear = ShadowEpsilon;
    rtc_ray_hit.ray.tfar = std::numeric_limits<Float>::infinity();
    // Target geometry id, read back by intersectionFilter.
    rtc_ray_hit.ray.mask = mesh_id;
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    // Also guard against hits on other meshes in case the filter was not run.
    if (rtc_ray_hit.hit.geomID == RTC_INVALID_GEOMETRY_ID || static_cast<int>(rtc_ray_hit.hit.geomID) != mesh_id) {
        its.t = std::numeric_limits<Float>::infinity();
        its.ptr_shape = nullptr;
        return false;
    } else {
        // Fill in the corresponding pointers
        its.indices[0] = static_cast<int>(rtc_ray_hit.hit.geomID);
        its.shape_id = its.indices[0];
        assert(its.shape_id == mesh_id);
        // Ray-Shape intersection
        its.indices[1] = static_cast<int>(rtc_ray_hit.hit.primID);
        its.triangle_id = its.indices[1];
        return true;
    }
}


// ! warning: didn't implement memmove, using memcpy as fallback which can result in errors
// NOTE(review): INACTIVE_FN is defined outside this file. Presumably it marks
// each noinline Embree wrapper above as inactive for Enzyme automatic
// differentiation (cf. the commented-out __enzyme_virtualreverse declaration
// near the top of the file), so the BVH queries themselves are not
// differentiated. Confirm against the macro definition.
INACTIVE_FN(_rayIntersect, _rayIntersect);
INACTIVE_FN(_rayIntersect2, _rayIntersect2);
INACTIVE_FN(_isIntersect, _isIntersect);
INACTIVE_FN(_rayIntersectShape, _rayIntersectShape);

/**
 * Builds a scene from already-constructed components and immediately calls
 * configure() to build the Embree scene and the sampling distributions.
 *
 * The pointer vectors are copied. ~Scene() later passes each list to
 * delete_vector, so the Scene presumably takes ownership of the pointees.
 * NOTE(review): `use_hierarchy` is accepted but not used by this constructor.
 */
Scene::Scene(
    const Camera &camera, const std::vector<Shape *> &shapes, const std::vector<BSDF *> &bsdfs,  
    const std::vector<Emitter *> &area_lights, const std::vector<PhaseFunction *> &phases,
    const std::vector<Medium *> &mediums, bool use_hierarchy)
    : camera(camera), shape_list(shapes), bsdf_list(bsdfs)
    , emitter_list(area_lights), phase_list(phases), medium_list(mediums)
{
    // Components are present; configure() advances the state to ESConfigured.
    m_state = ESLoaded;
    configure();
}


void Scene::configure()
{
    if (embree_scene)
        rtcReleaseScene(embree_scene);
    if (embree_device)
        rtcReleaseDevice(embree_device);

    // Initialize Embree scene
    embree_device = rtcNewDevice(nullptr);
    embree_scene = rtcNewScene(embree_device);
    rtcSetSceneBuildQuality(embree_scene, RTC_BUILD_QUALITY_HIGH);
    rtcSetSceneFlags(embree_scene, RTC_SCENE_FLAG_ROBUST);

    // Copy the scene into Embree (since Embree requires 16 bytes alignment)
    for ( const Shape *shape : shape_list ) {
        auto mesh = rtcNewGeometry(embree_device, RTC_GEOMETRY_TYPE_TRIANGLE);
        auto vertices = (Vector4f *)rtcSetNewGeometryBuffer(
            mesh, RTC_BUFFER_TYPE_VERTEX, 0, RTC_FORMAT_FLOAT3, sizeof(Vector4f),
            shape->num_vertices);
        for ( auto i = 0; i < shape->num_vertices; i++ ) {
            auto vertex = shape->getVertex(i);
            vertices[i] = Vector4f(vertex(0), vertex(1), vertex(2), 0.f);
        }
        auto triangles = (Vector3i *)rtcSetNewGeometryBuffer(
            mesh, RTC_BUFFER_TYPE_INDEX, 0, RTC_FORMAT_UINT3, sizeof(Vector3i),
            shape->num_triangles);
        for ( auto i = 0; i < shape->num_triangles; i++ )
            triangles[i] = shape->getIndices(i);

        rtcSetGeometryVertexAttributeCount(mesh, 1);
        rtcCommitGeometry(mesh);
        // NOTE: id start from 0
        rtcAttachGeometry(embree_scene, mesh);
        rtcReleaseGeometry(mesh);
    }
    rtcCommitScene(embree_scene);

    num_lights = emitter_list.size();
    light_distrb.clear();
    point_emitter_list.clear();
    area_emitter_list.clear();

    if ( emitter_list.size() > 0 ) {
        light_distrb.reserve(num_lights);
        for ( int i = 0; i < num_lights; ++i ) {
            int shape_id = emitter_list[i]->getShapeID();
            if ( shape_id >= 0 ) {
                const Shape &shape = *shape_list[shape_id];
                light_distrb.append(shape.getArea() *
                                    emitter_list[i]->getIntensity().maxCoeff());
                area_light_distrb.append(shape.getArea() *
                                         emitter_list[i]->getIntensity().maxCoeff());
                area_emitter_list.push_back(emitter_list[i]);
            }
            else {
                // power-heuristic for point light
                light_distrb.append(emitter_list[i]->getIntensity().maxCoeff());
                point_light_distrb.append(emitter_list[i]->getIntensity().maxCoeff());
                point_emitter_list.push_back(emitter_list[i]);
            }
        }
        light_distrb.normalize();
        if (!point_emitter_list.empty())
            point_light_distrb.normalize();
        if (!area_emitter_list.empty())
            area_light_distrb.normalize();
    }

    shape_distrb.clear();
    shape_distrb.reserve(shape_list.size());
    for (int shape_id = 0; shape_id < (int)shape_list.size(); shape_id++)
        shape_distrb.append(shape_list[shape_id]->getArea());
    shape_distrb.normalize();

    // assign shape pointer
    for ( const auto &emitter : emitter_list )
        emitter->shape_ptr = emitter->shape_id < 0 ? nullptr : shape_list[emitter->shape_id];
    initTetmesh();

    m_state = ESConfigured;
}


void Scene::initTetmesh()
{
    for ( int i = 0; i < medium_list.size(); i++ ) {
        std::vector<Vector3> vertices;
        std::vector<Vector3i> faces;
        std::vector<std::pair<int, int>> tet_ids;

        for ( int j = 0; j < shape_list.size(); j++ ) {
            const Shape *shape = getShape(j);
            if ( shape->med_ext_id == i || shape->med_int_id == i ) {
                int size = static_cast<int>(vertices.size());
                auto indices = shape->indices;
                for ( auto &idx : indices )
                    idx = idx + Vector3i(size, size, size);
                vertices.insert(vertices.end(), shape->vertices.begin(), shape->vertices.end());
                faces.insert(faces.end(), indices.begin(), indices.end());
                for ( int k = 0; k < shape->vertices.size(); k++ )
                    tet_ids.push_back({j, k});
            }
        }
        getMedium(i)->setTetmesh(vertices, faces, tet_ids);
    }
}


/**
 * Frees every component list via delete_vector, then releases the Embree
 * scene before the device that created it.
 *
 * The Embree handles are null-checked, matching configure(). A Scene whose
 * configure() never ran presumably still holds null handles, and those must
 * not be passed to rtcRelease*.
 */
Scene::~Scene()
{
    delete_vector(shape_list);
    delete_vector(bsdf_list);
    delete_vector(emitter_list);
    delete_vector(phase_list);
    delete_vector(medium_list);
    if (embree_scene)
        rtcReleaseScene(embree_scene);
    if (embree_device)
        rtcReleaseDevice(embree_device);
}


/**
 * Populates this scene from `file_name` with psdr::SceneLoader. When
 * `auto_configure` is set, it then rebuilds the derived state via configure().
 */
void Scene::load_file(const char *file_name, bool auto_configure)
{
    psdr::SceneLoader::load_from_file(file_name, *this);
    if ( !auto_configure )
        return;
    configure();
}


/**
 * Binary visibility between p and q that ignores surfaces with a null BSDF.
 *
 * Casts a ray from p towards q. Hits on null-BSDF surfaces are stepped over;
 * the first hit on any other surface makes the pair invisible.
 *
 * @param pOnSurface p lies on a surface: start the search at ShadowEpsilon
 * @param qOnSurface q lies on a surface: stop at (1 - ShadowEpsilon) of the
 *                   remaining distance so q's own surface is not reported
 * @return true if no non-null surface lies between p and q
 */
bool Scene::isVisible(
    const Vector &p, bool pOnSurface, const Vector &q, bool qOnSurface) const
{
    Vector dir = q - p;
    Float dist = dir.norm();
    dir /= dist;
    Float tfar = qOnSurface ? (1.0f - ShadowEpsilon) * dist : dist;

    RTCIntersectContext rtc_context;
    rtcInitIntersectContext(&rtc_context);
    RTCRayHit rtc_ray_hit;
    rtc_ray_hit.ray.org_x = p.x();
    rtc_ray_hit.ray.org_y = p.y();
    rtc_ray_hit.ray.org_z = p.z();
    rtc_ray_hit.ray.dir_x = dir.x();
    rtc_ray_hit.ray.dir_y = dir.y();
    rtc_ray_hit.ray.dir_z = dir.z();
    rtc_ray_hit.ray.mask = (unsigned int)(-1);
    rtc_ray_hit.ray.time = 0.f;
    rtc_ray_hit.ray.flags = 0;
    rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
    rtc_ray_hit.ray.tnear = pOnSurface ? ShadowEpsilon : 0.0f;
    rtc_ray_hit.ray.tfar = tfar;
    rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);

    while ( rtc_ray_hit.hit.geomID != RTC_INVALID_GEOMETRY_ID ) {
        int bsdf_index = shape_list[(int)rtc_ray_hit.hit.geomID]->bsdf_id;
        assert(bsdf_index >= 0);
        if ( !bsdf_list[bsdf_index]->isNull() )
            return false;

        // Step just past the null surface. The origin moves by
        // tfar + ShadowEpsilon, so the distance left to q must shrink by the
        // same amount. Subtracting only tfar lets the next segment overshoot q,
        // which can report q's own surface as an occluder.
        Float tHit = rtc_ray_hit.ray.tfar + ShadowEpsilon;
        rtc_ray_hit.ray.org_x += dir.x() * tHit;
        rtc_ray_hit.ray.org_y += dir.y() * tHit;
        rtc_ray_hit.ray.org_z += dir.z() * tHit;
        dist -= tHit;
        tfar = qOnSurface ? (1.0f - ShadowEpsilon) * dist : dist;
        // Nothing left between the new origin and q: the interval
        // [ShadowEpsilon, tfar] would be empty.
        if ( tfar < ShadowEpsilon )
            return true;
        rtc_ray_hit.ray.tnear = ShadowEpsilon;
        rtc_ray_hit.ray.tfar = tfar;
        rtc_ray_hit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
        rtc_ray_hit.hit.primID = RTC_INVALID_GEOMETRY_ID;
        rtc_ray_hit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
        rtcIntersect1(embree_scene, &rtc_context, &rtc_ray_hit);
    }
    return true;
}


/**
 * Transmittance along `ray` from its origin to ray.at(remaining).
 * Convenience overload of the point-to-point version; the end point is
 * always treated as lying on a surface.
 */
Float Scene::evalTransmittance(
    const Ray &ray, bool onSurface, const Medium *ptr_medium, Float remaining, RndSampler *sampler) const
{
    const bool end_on_surface = true;
    return evalTransmittance(ray.org, onSurface, ray.at(remaining), end_on_surface,
                             ptr_medium, sampler);
}


/**
 * Transmittance along the straight segment p1 -> p2, starting in `medium`.
 *
 * Surfaces with a null BSDF are crossed, and medium transitions at them are
 * followed. Hitting any other surface makes the result 0. For each traversed
 * segment, the current medium's evalTransmittance is multiplied in.
 *
 * @param p1OnSurface start each search at ShadowEpsilon instead of 0
 * @param p2OnSurface stop ShadowEpsilon short of p2
 * @param medium      medium at p1 (nullptr = no medium)
 * @param sampler     forwarded to Medium::evalTransmittance
 * @return accumulated transmittance; 0 on occlusion by a non-null surface, on an
 *         inconsistent medium transition, or when max_null_interactions
 *         iterations all ran.
 */
Float Scene::evalTransmittance(
    const Vector &p1, bool p1OnSurface, const Vector &p2, bool p2OnSurface, const Medium *medium,
    RndSampler *sampler) const
{
    Vector dir = p2 - p1;
    Float remaining = dir.norm();
    dir /= remaining;
    Float transmittance = 1.;
    Intersection its;
    Ray ray(p1, dir,
            p1OnSurface ? ShadowEpsilon : 0.,
            p2OnSurface ? remaining - ShadowEpsilon : remaining);

    int interactions;
    for ( interactions = 0; interactions < max_null_interactions && remaining > Epsilon; ++interactions ) {
        // This overload searches only within [ray.tmin, ray.tmax].
        bool surface = rayIntersect(ray, its, ESpatial);

        // if hit a non-null surface, return 0
        if ( surface && !its.getBSDF()->isNull() ) return 0.0;

        // if inside a medium, evaluate transmittance
        // (its.t is +inf on a miss, hence the clamp to the remaining distance)
        if ( medium ) {
            transmittance *= medium->evalTransmittance(
                ray, 0, std::min(its.t, remaining), sampler);
        }

        // if hit nothing or the transmittance is 0, terminate
        if ( !surface || transmittance < Epsilon ) break;

        // update loop variables: ray, remaining, medium, interactions
        if ( its.isMediumTransition() ) {
            // The medium we arrived from must match the one on the incoming side.
            if ( medium != its.getTargetMedium(-dir) ) {
                // fprintf(stdcerr, "inconsistent medium transitions\n");
                return 0.;
            }
            medium = its.getTargetMedium(dir);
        }

        // Restart from the crossed null surface with the shortened segment.
        ray.org = ray(its.t);
        remaining -= its.t;
        ray.tmax = p2OnSurface ? remaining - ShadowEpsilon : remaining;
        ray.tmin = ShadowEpsilon;
    }
    // NOTE(review): this also fires when the last allowed iteration finished
    // without a break, even if `remaining` had just dropped below Epsilon.
    if (interactions == max_null_interactions) {
        if ( verbose )
            fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        return 0.;
    }

    return transmittance;
}


// TODO: only works for homogeneous medium
std::pair<Float, Float> Scene::evalTransmittanceAndPdf(
    const Vector &p1, bool p1OnSurface, const Vector &p2, bool p2OnSurface,
    const Medium *medium, RndSampler *sampler) const
{
    Vector dir = p2 - p1;
    Float remaining = dir.norm();
    dir /= remaining;
    Float transmittance = 1.;
    Intersection its;
    Ray ray(p1, dir,
            p1OnSurface ? ShadowEpsilon : 0.,
            p2OnSurface ? remaining - ShadowEpsilon : remaining);

    Float pdf = 1.0;

    int interactions;
    for (interactions = 0; interactions < max_null_interactions && remaining > Epsilon; ++interactions) {
        bool surface = rayIntersect(ray, its, ESpatial);

        // if hit a non-null surface, return 0
        if ( surface && !its.getBSDF()->isNull() ) return {0., 0.};

        // if inside a medium, evaluate transmittance
        if ( medium ) {
            assert(medium->isHomogeneous());

            Float trans = medium->evalTransmittance(
                ray, 0, std::min(its.t, remaining), sampler);
            transmittance *= trans;
            Float sampling_weight = static_cast<const Homogeneous *>(medium)->sampling_weight;
            Spectrum sigma_s = static_cast<const Homogeneous *>(medium)->sigma_s;
            Float sigma_t = static_cast<const Homogeneous *>(medium)->sigma_t;

            if ( its.t < remaining + ShadowEpsilon ) {
                // hit a surface
                pdf *= 1 - sampling_weight * (1 - exp(sigma_t * (-its.t)));
            }
            else {
                // hit a medium
                pdf *= sampling_weight * sigma_t * trans;
            }
        }

        // if hit nothing or the transmittance is 0, terminate
        if ( !surface || transmittance < Epsilon ) break;

        // update loop variables: ray, remaining, medium, interactions
        if ( its.isMediumTransition() ) {
            if ( medium != its.getTargetMedium(-dir) ) {
                // fprintf(stdcerr, "inconsistent medium transitions\n");
                return {0.0, 0.0};
            }
            medium = its.getTargetMedium(dir);
        }

        ray.org = ray(its.t);
        remaining -= its.t;
        ray.tmax = p2OnSurface ? remaining - ShadowEpsilon : remaining;
        ray.tmin = ShadowEpsilon;
    }
    if (interactions == max_null_interactions) {
        if ( verbose )
            fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        return {0., 0.};
    }

    return {transmittance, pdf};
}


/**
 * Finds the first surface along `_ray` whose BSDF is not null, stepping
 * through null-BSDF surfaces.
 *
 * @param onSurface   the ray starts on a surface; after the first crossing,
 *                    every re-launch is treated as starting on a surface
 * @param its         out: the last intersection found
 * @return true if a non-null surface was hit; false on a miss or when
 *         its.isValid() reports an invalid intersection
 */
bool Scene::traceForSurface(const Ray &_ray, bool onSurface, Intersection &its) const
{
    for ( Ray ray(_ray); rayIntersect(ray, onSurface, its); onSurface = true ) {
        if ( !its.isValid() )
            return false;
        if ( !its.ptr_bsdf->isNull() )
            return true;
        // Continue from the null surface that was just hit.
        ray.org = its.p;
    }
    return false;
}


/**
 * Equal-transmittance search. Walks along `_ray`, passing through every
 * surface hit and following medium transitions, until the accumulated
 * transmittance would reach `targetTransmittance`, then reports that point.
 *
 * Media are asserted to be homogeneous: the in-medium distance is solved in
 * closed form as -log(T_remaining) / sigT.
 *
 * @param onSurface the ray starts on a surface; later re-launches always do
 * @param mRec      out: on success, medium, p (= _ray(t)), t and transmittance.
 *                  On failure, medium stays nullptr and transmittance holds
 *                  the product accumulated so far.
 * @return true if the target transmittance is reached inside a medium; false
 *         when the ray escapes or a medium transition is inconsistent.
 */
bool Scene::traceForMedium(
    const Ray &_ray, bool onSurface, const Medium *medium, Float targetTransmittance,
    RndSampler *sampler, MediumSamplingRecord &mRec) const
{
    Ray ray(_ray);
    mRec.medium = nullptr;
    mRec.transmittance = 1.0;
    Float distanceTraveled = 0.0;
    Intersection its;
    while (mRec.transmittance > targetTransmittance) {
        bool surface = rayIntersect(ray, onSurface, its);
        if ( medium ) {
            assert( medium->isHomogeneous() );
            // its.t is +inf on a miss; clamp to a large finite distance.
            Float transmittance = medium->evalTransmittance(
                ray, 0, std::min(its.t, 1e7), sampler);
            Float remainingTransmittance = targetTransmittance / mRec.transmittance;

            if ( transmittance > remainingTransmittance ) {
                // the equal transmittance point is outside the medium
                mRec.transmittance *= transmittance; 
            }
            else {
                // the equal transmittance point is inside the medium, find the antithetic point
                Float t = -std::log(remainingTransmittance) / medium->sigT(Vector());
                distanceTraveled += t;
                mRec.transmittance *= medium->evalTransmittance(ray, 0, t, sampler);
                mRec.medium = medium;
                mRec.p = _ray(distanceTraveled);
                mRec.t = distanceTraveled;
                return true;
            }
        }
        if ( !surface ) break;

        // update loop variables: ray, remaining, medium, interactions
        if ( its.isMediumTransition() ) {
            if (medium != its.getTargetMedium(-ray.dir)) {
                // fprintf(stdcerr, "inconsistent medium transitions\n");
                return false;
            }
            medium = its.getTargetMedium(ray.dir);
        }
        distanceTraveled += its.t;
        ray.org = ray(its.t);
        onSurface = true;
    }
    return false;
}

// equal distance sampling
/**
 * Equal-distance search. Walks along `_ray`, passing through every surface hit
 * and following medium transitions, and reports the point at path distance
 * `tarDist` if that point lies inside a medium.
 *
 * @param onSurface the ray starts on a surface; later re-launches always do
 * @param mRec      out: on success, medium, p (= _ray(tarDist)), t and the
 *                  transmittance accumulated up to that point
 * @return true if the point at tarDist is inside a medium; false when it lies
 *         outside any medium, the ray escapes first, or a medium transition is
 *         inconsistent.
 */
bool Scene::traceForMedium2(
    const Ray &_ray, bool onSurface, const Medium *medium, Float tarDist,
    RndSampler *sampler, MediumSamplingRecord &mRec) const
{
    Ray ray(_ray);
    mRec.medium = nullptr;
    mRec.transmittance = 1.0;
    Float distanceTraveled = 0.0;
    Intersection its;
    while ( distanceTraveled < tarDist ) {
        bool surface = rayIntersect(ray, onSurface, its);
        if ( medium ) {
            // On a miss its.t is +inf, so the target point is inside this medium.
            if ( tarDist > distanceTraveled + its.t ) {
                // the equal distance point is outside the medium
                mRec.transmittance *= medium->evalTransmittance(
                    ray, 0, its.t, sampler);
            }
            else {
                // the equal distance point is inside the medium, find the antithetic point
                Float t = tarDist - distanceTraveled;
                distanceTraveled += t;
                mRec.transmittance *= medium->evalTransmittance(ray, 0, t, sampler);
                mRec.medium = medium;
                mRec.p = _ray(distanceTraveled);
                mRec.t = distanceTraveled;
                return true;
            }
        }
        if ( !surface ) break;

        // update loop variables: ray, remaining, medium, interactions
        if ( its.isMediumTransition() ) {
            if ( medium != its.getTargetMedium(-ray.dir) ) {
                // fprintf(stdcerr, "inconsistent medium transitions\n");
                return false;
            }
            medium = its.getTargetMedium(ray.dir);
        }
        distanceTraveled += its.t;
        ray.org = ray(its.t);
        onSurface = true;
    }
    return false;
}


/**
 * Direct light sampling towards an emitter chosen from light_distrb.
 *
 * The chosen emitter fills `dRec`, and dRec.pdf is scaled by the detached
 * light-selection probability. Returns zero when dRec.dir points below
 * dRec.refN or any geometry blocks the shadow ray (null BSDFs are not skipped).
 * Otherwise returns the emitter's sample value divided by the detached
 * selection probability.
 */
Spectrum Scene::sampleEmitterDirect(
    const Vector2 &_rnd_light, DirectSamplingRecord &dRec) const
{
    Vector2 rnd_light(_rnd_light);
    Float emPdf;
    const int chosen = light_distrb.sampleReuse(rnd_light[0], emPdf);
    const Emitter *emitter = emitter_list[chosen];
    dRec.shape_id = emitter->shape_id;
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);
    dRec.pdf *= detach(emPdf);

    // Orientation test first (media / transmissive surfaces set refN = 0),
    // then the shadow ray, which stops just short of the light point.
    const bool below_surface = dRec.dir.dot(dRec.refN) < 0.;
    if ( below_surface ||
         _isIntersect(embree_scene, Ray(dRec.ref, dRec.dir), (1 - ShadowEpsilon) * dRec.dist) )
        return Spectrum::Zero();

    value /= detach(emPdf);
    return value;
}


// TODO need refactoring
/**
 * Direct light sampling from a surface point, without participating media.
 *
 * @param its        shading point; its.ptr_bsdf decides whether light below the
 *                   geometric/shading normal is allowed (transmissive or
 *                   two-sided BSDFs)
 * @param _rnd_light 2D random sample; x is consumed by light_distrb.sampleReuse
 * @param sampler    not used here
 * @param wo         out: direction to the light in its local frame (on success)
 * @param pdf        out: first the light-selection pdf, then
 *                   - point light: divided into the result and set to -1.0
 *                     (NOTE(review): presumably a sentinel for "delta light"
 *                     that callers check; confirm);
 *                   - area light: divided by the shape area, then converted
 *                     to solid angle (* dist^2 / cos at the light).
 * @return the light contribution, or zero when the light is on the wrong side,
 *         faces away, has pdf <= Epsilon, or is occluded by any geometry
 *         (null BSDFs included)
 */
Spectrum Scene::sampleEmitterDirect(
    const Intersection &its, const Vector2 &_rnd_light, RndSampler *sampler,
    Vector &wo, Float &pdf) const
{
    Vector2 rnd_light(_rnd_light);
    const int light_id = light_distrb.sampleReuse(rnd_light[0], pdf);
    const int shape_id = emitter_list[light_id]->getShapeID();
    if ( shape_id < 0 ) {
        // point light
        DirectSamplingRecord dRec(its.p);
        emitter_list[light_id]->sampleDirect(rnd_light, dRec);
        const Vector &light_pos = dRec.p;
        Vector dir = light_pos - its.p;
        Float dist = dir.norm();
        dir /= dist;
        // Light must be above the surface unless the BSDF transmits or is two-sided.
        if ((its.ptr_bsdf->isTransmissive() || its.ptr_bsdf->isTwosided() ||
             (dir.dot(its.geoFrame.n) > Epsilon &&
              dir.dot(its.shFrame.n) > Epsilon)))
        {
            if ( !_isIntersect(embree_scene, Ray(its.p, dir), dist) ) {
                wo = its.toLocal(dir);
                Float invDist = 1. / dist;
                // Inverse-square falloff, divided by the selection pdf.
                Spectrum ret =
                    emitter_list[light_id]->getIntensity() * invDist * invDist / pdf;
                pdf = -1.0;
                return ret;
            }
        }
    }
    else {
        // area light: sample a point on the emitter's shape
        PositionSamplingRecord pRec;
        shape_list[shape_id]->samplePosition(rnd_light, pRec);
        pdf /= shape_list[shape_id]->getArea();
        const Vector &light_pos = pRec.p, &light_norm = pRec.n;
        Vector dir = light_pos - its.p;
        // Same hemisphere test as above, plus the light normal must face its.p.
        if ((its.ptr_bsdf->isTransmissive() || its.ptr_bsdf->isTwosided() ||
             (dir.dot(its.geoFrame.n) > Epsilon &&
              dir.dot(its.shFrame.n) > Epsilon)) &&
            dir.dot(light_norm) < -Epsilon && pdf > Epsilon)
        {
            Float dist = dir.norm();
            dir = dir / dist;
            if ( !_isIntersect(embree_scene, Ray(its.p, dir), (1 - ShadowEpsilon) * dist) ) {
                // Area-measure pdf -> solid-angle pdf.
                pdf *= dist * dist / light_norm.dot(-dir);
                wo = its.toLocal(dir);
                return emitter_list[light_id]->eval(light_norm, -dir) / detach(pdf);
            }
        }
    }
    return Spectrum::Zero();
}


/**
 * Samples a light as seen from `pscatter` and attenuates it by the
 * transmittance along the connecting segment.
 *
 * @param pscatter   start point; the shadow ray is traced as not on a surface
 * @param ptr_medium medium the shadow ray starts in
 * @param wo         out: world-space direction to the light (set on success)
 * @param pdf        out: light-selection pdf, multiplied by dRec.pdf on success
 * @return emitter sample value / selection pdf * transmittance, or zero when an
 *         area light's normal does not face pscatter, the selection pdf is
 *         <= Epsilon, or the transmittance is <= Epsilon
 */
Spectrum Scene::sampleAttenuatedEmitterDirect(
    const Vector &pscatter, const Vector2 &_rnd_light, RndSampler *sampler, const Medium *ptr_medium,
    Vector &wo, Float &pdf) const
{
    Vector2 rnd_light(_rnd_light);
    const int chosen = light_distrb.sampleReuse(rnd_light[0], pdf);
    const bool is_point_light = emitter_list[chosen]->getShapeID() < 0;
    const Emitter *emitter = emitter_list[chosen];
    DirectSamplingRecord dRec(pscatter);
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);

    Vector dir = dRec.p - pscatter;
    // Area lights: the light normal dRec.n must point towards pscatter.
    const bool faces_point = is_point_light || dir.dot(dRec.n) < 0;
    if ( !(faces_point && pdf > Epsilon) )
        return Spectrum::Zero();

    const Float dist = dir.norm();
    dir = dir / dist;
    Ray shadow_ray(pscatter, dir);
    const Float transmittance = evalTransmittance(shadow_ray, false, ptr_medium, dist,
                                                  sampler);
    if ( !(transmittance > Epsilon) )
        return Spectrum::Zero();

    wo = dir;
    value /= pdf;
    pdf *= dRec.pdf;
    return value * transmittance;
}


/**
 * Samples a light from dRec.ref and attenuates it by the transmittance from
 * dRec.ref (treated as not on a surface) to the sampled light point (on a
 * surface) through `medium`.
 *
 * dRec is filled by the chosen emitter, and dRec.pdf is scaled by the detached
 * light-selection probability. Returns zero when the resulting dRec.pdf is not
 * greater than Epsilon.
 */
Spectrum Scene::sampleAttenuatedEmitterDirect(
    DirectSamplingRecord &dRec, const Vector2 &_rnd_light, RndSampler *sampler, const Medium *medium) const
{
    Vector2 rnd_light(_rnd_light);
    Float emPdf;
    const int chosen = light_distrb.sampleReuse(rnd_light[0], emPdf);
    const Emitter *emitter = emitter_list[chosen];
    dRec.shape_id = emitter->shape_id;
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);
    dRec.pdf *= detach(emPdf);

    if ( !(dRec.pdf > Epsilon) ) {
        value.setZero();
        return value;
    }
    value *= evalTransmittance(dRec.ref, false, dRec.p, true, medium, sampler) / detach(emPdf);
    return value;
}


/**
 * Same as sampleAttenuatedEmitterDirect(dRec, ...), except that the start
 * point dRec.ref is treated as lying on a surface, so the transmittance
 * search starts at ShadowEpsilon.
 *
 * Returns zero when the scaled dRec.pdf is not greater than Epsilon.
 */
Spectrum Scene::sampleBoundaryAttenuatedEmitterDirect(
    DirectSamplingRecord &dRec, const Vector2 &_rnd_light, RndSampler *sampler, const Medium *medium) const
{
    Vector2 rnd_light(_rnd_light);
    Float emPdf;
    const int chosen = light_distrb.sampleReuse(rnd_light[0], emPdf);
    const Emitter *emitter = emitter_list[chosen];
    dRec.shape_id = emitter->shape_id;
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);
    dRec.pdf *= detach(emPdf);

    if ( !(dRec.pdf > Epsilon) ) {
        value.setZero();
        return value;
    }
    value *= evalTransmittance(dRec.ref, true, dRec.p, true, medium, sampler) / detach(emPdf);
    return value;
}

/**
 * Samples a light from a surface point and attenuates it by the transmittance
 * from dRec.ref to the light point (both treated as on a surface).
 *
 * If `its` is valid and a medium transition, the shadow ray starts in
 * its.getTargetMedium(dRec.dir) instead of `medium`. Returns zero when the
 * scaled dRec.pdf is not greater than Epsilon.
 */
Spectrum Scene::sampleAttenuatedEmitterDirect(
    DirectSamplingRecord &dRec, const Intersection &its, const Vector2 &_rnd_light,
    RndSampler *sampler, const Medium *medium) const
{
    Vector2 rnd_light(_rnd_light);
    Float emPdf;
    const int chosen = light_distrb.sampleReuse(rnd_light[0], emPdf);
    const Emitter *emitter = emitter_list[chosen];
    dRec.shape_id = emitter->shape_id;
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);
    dRec.pdf *= detach(emPdf);

    if ( !(dRec.pdf > Epsilon) )
        return Spectrum::Zero();

    const bool crosses_medium_boundary = its.isValid() && its.isMediumTransition();
    if ( crosses_medium_boundary )
        medium = its.getTargetMedium(dRec.dir);
    value *= evalTransmittance(dRec.ref, true, dRec.p, true, medium, sampler) / detach(emPdf);
    return value;
}


/**
 * Direct light sampling from a surface point, attenuated by participating media.
 *
 * @param its        shading point; its.ptr_bsdf decides whether light below the
 *                   geometric/shading normal is allowed
 * @param _rnd_light 2D random sample; x is consumed by light_distrb.sampleReuse
 * @param sampler    forwarded to evalTransmittance
 * @param ptr_medium medium the shadow ray starts in; replaced by
 *                   its.getTargetMedium(dir) when `its` is a medium transition
 * @param wo         out: light direction in its local frame. Written as soon as
 *                   the geometric tests pass, even if the transmittance then
 *                   makes the result zero.
 * @param pdf        out: light-selection pdf; multiplied by dRec.pdf on success
 * @param flag       not used here
 * @return emitter sample value / selection pdf * transmittance, or zero
 */
Spectrum Scene::sampleAttenuatedEmitterDirect(
    const Intersection &its, const Vector2 &_rnd_light, RndSampler *sampler, const Medium *ptr_medium,
    Vector &wo, Float &pdf, bool flag) const
{
    Vector2 rnd_light(_rnd_light);
    const int light_id = light_distrb.sampleReuse(rnd_light[0], pdf);
    const int shape_id = emitter_list[light_id]->getShapeID();
    const Emitter *emitter = emitter_list[light_id];
    DirectSamplingRecord dRec(its.p);
    Spectrum value = emitter->sampleDirect(rnd_light, dRec);
    const Vector &light_pos = dRec.p, &light_norm = dRec.n;

    Vector dir = light_pos - its.p;
    // Hemisphere test (skipped for transmissive / two-sided BSDFs). For area
    // lights, the light normal must also face its.p and the pdf must be usable.
    if ((its.ptr_bsdf->isTransmissive() || its.ptr_bsdf->isTwosided() ||
         (dir.dot(its.geoFrame.n) > Epsilon &&
          dir.dot(its.shFrame.n) > Epsilon)) &&
        (shape_id < 0 || (dir.dot(light_norm) < -Epsilon && pdf > Epsilon)))
    {
        if (its.isMediumTransition())
            ptr_medium = its.getTargetMedium(dir);

        Float dist = dir.norm();
        dir = dir / dist;
        Ray shadow_ray(its.p, dir);
        wo = its.toLocal(dir);

        // The shadow ray leaves the surface at its.p, so onSurface = true.
        Float transmittance = evalTransmittance(shadow_ray, true, ptr_medium, dist, sampler);
        if ( transmittance > Epsilon ) {
            value /= pdf;
            pdf *= dRec.pdf;
            return transmittance * value;
        }
    }
    return Spectrum::Zero();
}


/**
 * Area-measure pdf of having sampled the point `its` on its emitter: the
 * light-selection probability of its shape's light divided by the shape area.
 * The shape must be an emitter (light_id >= 0, asserted).
 */
Float Scene::pdfEmitterSample(const Intersection &its) const
{
    const auto *shape = its.ptr_shape;
    const int light_id = shape->light_id;
    assert( light_id >= 0 );
    const auto area = shape->getArea();
    return light_distrb[light_id] / area;
}
INACTIVE_FN(Scene_pdfEmitterSample, &Scene::pdfEmitterSample);


/**
 * Closest intersection along `ray`, searching from ShadowEpsilon (if
 * `onSurface`) or 0 to +inf. The ray's tmin/tmax are ignored.
 *
 * On a hit, fills `its`: shape/triangle ids, shape, interior/exterior media,
 * emitter, BSDF. It then lets the shape compute the geometric details in the
 * requested `mode`.
 */
bool Scene::rayIntersect(
    const Ray &ray, bool onSurface, Intersection &its, IntersectionMode mode) const
{
    if ( !_rayIntersect(embree_scene, ray, onSurface, its) )
        return false;

    its.shape_id = its.indices[0];
    its.triangle_id = its.indices[1];
    its.ptr_shape = shape_list[its.indices[0]];

    const auto *shape = its.ptr_shape;
    const auto ext_id = shape->med_ext_id;
    const auto int_id = shape->med_int_id;
    its.ext_med_id = ext_id;
    its.int_med_id = int_id;
    its.ptr_med_int = (int_id >= 0) ? medium_list[int_id] : nullptr;
    its.ptr_med_ext = (ext_id >= 0) ? medium_list[ext_id] : nullptr;
    its.ptr_emitter = (shape->light_id >= 0) ? emitter_list[shape->light_id] : nullptr;
    assert(shape->bsdf_id >= 0);
    its.ptr_bsdf = bsdf_list[shape->bsdf_id];

    // Shape-level ray/triangle intersection fills in the geometric fields.
    its.ptr_shape->rayIntersect(its.indices[1], ray, its,
                                mode);
    return true;
}


/**
 * Closest intersection along `ray` within [ray.tmin, ray.tmax].
 *
 * On a hit, fills `its`: shape/triangle ids, shape, interior/exterior media,
 * emitter, BSDF. It then lets the shape compute the geometric details in the
 * requested `mode`.
 */
bool Scene::rayIntersect(
    const Ray &ray, Intersection &its, IntersectionMode mode) const
{
    if ( !_rayIntersect2(embree_scene, ray, its) )
        return false;

    its.shape_id = its.indices[0];
    its.triangle_id = its.indices[1];
    its.ptr_shape = shape_list[its.indices[0]];

    const auto *shape = its.ptr_shape;
    const auto ext_id = shape->med_ext_id;
    const auto int_id = shape->med_int_id;
    its.ext_med_id = ext_id;
    its.int_med_id = int_id;
    its.ptr_med_int = (int_id >= 0) ? medium_list[int_id] : nullptr;
    its.ptr_med_ext = (ext_id >= 0) ? medium_list[ext_id] : nullptr;
    its.ptr_emitter = (shape->light_id >= 0) ? emitter_list[shape->light_id] : nullptr;
    assert(shape->bsdf_id >= 0);
    its.ptr_bsdf = bsdf_list[shape->bsdf_id];

    // Shape-level ray/triangle intersection fills in the geometric fields.
    its.ptr_shape->rayIntersect(its.indices[1], ray, its,
                                mode);
    return true;
}

/**
 * Ray query against the single shape with index `shape_idx`.
 *
 * A hit is asserted to lie on `shape_idx`. The record is then filled with that
 * shape's ids, media, emitter and BSDF, and completed by the shape's own
 * rayIntersect.
 *
 * @param shape_idx  Index into shape_list of the shape to test.
 * @param ray        Query ray.
 * @param its        Output intersection record.
 * @param mode       Forwarded unchanged to the shape-level rayIntersect.
 * @return true if the ray hit the shape, false otherwise.
 */
bool Scene::rayIntersectShape(
    const int shape_idx, const Ray &ray, Intersection &its, IntersectionMode mode) const
{
    if ( !_rayIntersectShape(shape_idx, embree_scene, ray, its) )
        return false;

    assert(shape_idx == its.indices[0]);
    its.shape_id = its.indices[0];
    its.ptr_shape = shape_list[its.indices[0]];
    its.triangle_id = its.indices[1];

    // Negative ids mean "no medium" / "not an emitter".
    const auto ext_id = its.ptr_shape->med_ext_id;
    const auto int_id = its.ptr_shape->med_int_id;
    const auto emitter_id = its.ptr_shape->light_id;
    its.ext_med_id = ext_id;
    its.int_med_id = int_id;
    its.ptr_med_int = int_id >= 0 ? medium_list[int_id] : nullptr;
    its.ptr_med_ext = ext_id >= 0 ? medium_list[ext_id] : nullptr;
    its.ptr_emitter = emitter_id >= 0 ? emitter_list[emitter_id] : nullptr;

    assert(its.ptr_shape->bsdf_id >= 0);
    its.ptr_bsdf = bsdf_list[its.ptr_shape->bsdf_id];

    // Delegate to the shape to complete the record for the hit triangle.
    its.ptr_shape->rayIntersect(its.indices[1], ray, its, mode);
    return true;
}


std::vector<Intersection> Scene::rayIntersectAll(const Ray &_ray, bool onSurface) const
{
    std::vector<Intersection> its_list;
    Ray ray(_ray);
    bool surface = false;
    Intersection its; // use _its to store the first intersection, its2 to keep track of the folowing ones

    int interactions;
    for (interactions = 0; interactions < max_null_interactions; interactions++)
    {
        surface = rayIntersect(ray, onSurface, its);

        // if hit an occluder or light source
        if (surface && (!its.getBSDF()->isNull() || its.isEmitter()))
            break;

        // hit nothing
        if (!surface)
            break;

        assert(its.getBSDF()->isNull());

        ray.org = ray(its.t);
        onSurface = true;
        its_list.push_back(its);
    }
    if (interactions == max_null_interactions)
    {
        if (verbose)
            fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        return its_list;
    }

    return its_list;
}

/**
 * Follow `_ray` through null-BSDF surfaces until it either scatters inside a
 * medium or reaches a surface with a non-null BSDF, and record that endpoint
 * in `bERec`.
 *
 * @param _ray       Ray to trace. rayIntersect is always called with
 *                   onSurface = true, so the origin is treated as lying on a
 *                   surface.
 * @param sampler    Sampler passed to Medium::sampleDistance.
 * @param max_depth  Maximum number of null-surface crossings.
 * @param med_id     Id of the medium containing the ray origin (resolved with
 *                   getMedium); updated when a null surface is a medium
 *                   transition.
 * @param bERec      Output endpoint record; only meaningful when true is
 *                   returned.
 * @return false if the ray escapes the scene, the sampled medium point is not
 *         found in the medium's tet mesh, the surface hit fails the
 *         normal-orientation check, or max_depth crossings are exhausted;
 *         true otherwise.
 */
bool Scene::trace(const Ray &_ray, RndSampler *sampler, int max_depth, int med_id,
                  BoundaryEndpointSamplingRecord &bERec) const
{
    Ray ray(_ray);
    int interactions = 0;
    Spectrum throughput = Spectrum::Ones();
    Intersection its;
    const Medium *medium = getMedium(med_id);
    while ( interactions < max_depth ) {
        if ( !rayIntersect(ray, true, its) ) return false;
        MediumSamplingRecord mRec;
        // Try to sample a scattering distance before the surface at its.t.
        // With no medium, sampleDistance is skipped and mRec keeps its defaults.
        bool inside_med = medium != nullptr &&
                          medium->sampleDistance(ray, its.t, sampler, mRec);

        if ( inside_med ) {
            // medium interaction: the endpoint is the sampled point mRec.p
            throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
            bERec.onSurface = false;
            bERec.mRec = mRec;
            bERec.interactions = interactions;
            bERec.p = mRec.p;
            bERec.wi = -ray.dir;
            bERec.throughput = throughput;
            bERec.pdf = mRec.pdfSuccess;
            bERec.med_id = med_id;
            // Locate the tetrahedron containing p; -1 is treated as failure.
            bERec.tet_id = medium->m_tetmesh.in_element(bERec.p);
            if (bERec.tet_id == -1) return false;
            bERec.barycentric4 = medium->m_tetmesh.getBarycentric(bERec.tet_id, bERec.p);
            return true;
        }

        // surface interaction: the ray reached its.t without scattering
        if ( medium ) throughput *= mRec.transmittance / mRec.pdfFailure;
        if ( !its.getBSDF()->isNull() ) {
            // Reject inconsistent hits. Transmissive BSDFs need the geometric
            // and shading normals on the same side of -ray.dir; other BSDFs
            // need both normals to face -ray.dir (dot > Epsilon).
            Float gnDotD = its.geoFrame.n.dot(-ray.dir);
            Float snDotD = its.shFrame.n.dot(-ray.dir);
            bool valid = (its.ptr_bsdf->isTransmissive() &&
                          math::signum(gnDotD) * math::signum(snDotD) > .5) ||
                         (!its.ptr_bsdf->isTransmissive() &&
                          gnDotD > Epsilon && snDotD > Epsilon);
            if ( !valid ) return false;

            bERec.onSurface = true;
            bERec.its = its;
            bERec.interactions = interactions;
            bERec.p = its.p;
            bERec.wi = -ray.dir;
            bERec.n = its.geoFrame.n;
            bERec.sh_n = its.shFrame.n;
            bERec.throughput = throughput;
            // NOTE(review): when medium == nullptr, mRec was never filled, so
            // this is MediumSamplingRecord's default pdfFailure — confirm that
            // default is the intended value (presumably 1).
            bERec.pdf = mRec.pdfFailure;
            bERec.shape_id = its.indices[0];
            bERec.tri_id = its.indices[1];
            // Barycentric weights (1 - u - v, u, v) of the hit point.
            bERec.barycentric3 = Vector(1 - its.barycentric.sum(),
                                        its.barycentric[0],
                                        its.barycentric[1]);
            return true;
        }

        // null surface: continue from the hit point, switching to the target
        // medium when its.isMediumTransition() reports a medium boundary
        ray.org = its.p;
        if ( its.isMediumTransition() ) {
            med_id = its.getTargetMediumId(ray.dir);
            medium = its.getTargetMedium(ray.dir);
        }
        interactions++;
    }
    // reach max depth
    return false;
}


/**
 * Trace `_ray` through null-BSDF surfaces looking for an emitter, accumulating
 * medium transmittance along the way.
 *
 * Marching stops at the first miss, or at the first hit whose BSDF is not null
 * or which is an emitter.
 *
 * @param _ray       Query ray.
 * @param onSurface  Whether the ray origin lies on a surface (first segment
 *                   only; later segments pass true).
 * @param sampler    Sampler passed to Medium::evalTransmittance.
 * @param medium     Medium containing the ray origin (may be nullptr); replaced
 *                   by the target medium at medium transitions.
 * @param _its       Receives the FIRST intersection along the ray. Later
 *                   intersections go into a local scratch record.
 * @param dRec       Filled with the emitter hit data when an emitter is found.
 * @return transmittance * Le(-ray.dir) if the stopping hit is an emitter.
 *         Spectrum::Zero() otherwise: a miss, a non-emitting occluder,
 *         transmittance below Epsilon after a null surface, or
 *         max_null_interactions reached.
 */
Spectrum Scene::rayIntersectAndLookForEmitter(
    const Ray &_ray, bool onSurface, RndSampler *sampler, const Medium *medium,
    Intersection &_its, DirectSamplingRecord &dRec) const
{
    Ray ray(_ray);
    Float transmittance = 1.0;
    bool surface = false;
    Intersection its2, *its = &_its; // _its receives the first intersection; its2 is scratch storage for the following ones

    int interactions;
    for ( interactions = 0; interactions < max_null_interactions; interactions++ ) {
        surface = rayIntersect(ray, onSurface, *its);

        // Attenuate over the segment [0, its->t] in the current medium.
        if (medium)
            transmittance *= medium->evalTransmittance(ray, 0.0, its->t, sampler);

        // if hit an occluder or light source
        if (surface && (!its->getBSDF()->isNull() || its->isEmitter()))
            break;

        // hit nothing
        if (!surface) break;

        assert( its->getBSDF()->isNull() );
        // Early out once the path is effectively fully attenuated.
        if ( transmittance < Epsilon ) return Spectrum::Zero();

        if (its->isMediumTransition())
            medium = its->getTargetMedium(ray.dir);

        // Continue from the null surface; keep _its holding the first hit.
        ray.org = ray(its->t);
        its = &its2;
        onSurface = true;
    }
    if (interactions == max_null_interactions) {
        if ( verbose )
            fprintf(stderr, "Max null interactions (%d) reached. Dead loop?\n", max_null_interactions);
        // Statistics::getInstance().getCounter("Warning", "Null interactions") += 1;
        return Spectrum::Zero();
    }

    if (surface) {
        if ( its->isEmitter() ) {
            dRec.shape_id = its->indices[0];
            dRec.tri_id = its->indices[1];
            dRec.barycentric = its->barycentric;
            dRec.ref = _ray.org;
            dRec.p = its->p;
            dRec.n = its->geoFrame.n;
            // NOTE(review): after null-surface crossings, its->t is presumably
            // the length of the last segment only (measured from the last null
            // surface), while dRec.ref is the original ray origin — verify
            // callers expect that.
            dRec.dist = its->t;
            dRec.G = geometric(dRec.ref, dRec.p, dRec.n);
            dRec.measure = EMArea;
            dRec.interactions = interactions;
            return transmittance * its->Le(-ray.dir);
        }
    }

    return Spectrum::Zero();
}

/**
 * Sample a point on an emitter chosen from light_distrb.
 *
 * rnd_light[0] selects the emitter via sampleReuse(), which rewrites that
 * component in the local copy. The copy is then reused to sample a position on
 * the emitter's shape, or on the emitter itself when it has no shape.
 *
 * @param _rnd_light  2D random sample.
 * @param its         Output record describing the sampled emitter point.
 * @param _pdf        Optional output for the pdf: the light selection pdf,
 *                    further divided by the shape area for shape-backed
 *                    emitters.
 * @return The emitter's getIntensity() divided by that pdf.
 */
Spectrum Scene::sampleEmitterPosition(
    const Vector2 &_rnd_light, Intersection &its, Float *_pdf) const
{
    Vector2 rnd(_rnd_light);
    Float pdf;
    const int light_id = light_distrb.sampleReuse(rnd[0], pdf);
    const int shape_id = emitter_list[light_id]->getShapeID();

    its.indices[0] = shape_id;
    its.ptr_emitter = emitter_list[light_id];

    PositionSamplingRecord pRec;
    if ( shape_id >= 0 ) {
        // Emitter backed by a shape: sample a point on that shape.
        auto shape = shape_list[shape_id];
        its.type = EVEmitter;
        its.indices[1] = shape->samplePosition(rnd, pRec);
        its.shape_id = shape_id;
        its.triangle_id = its.indices[1];
        its.barycentric = pRec.uv;
        pdf /= shape->getArea();

        its.ptr_shape = shape;
        its.ptr_bsdf = getBSDF(shape->bsdf_id);

        const int ext_id = shape->med_ext_id;
        const int int_id = shape->med_int_id;
        its.ptr_med_ext = ext_id >= 0 ? medium_list[ext_id] : nullptr;
        its.ptr_med_int = int_id >= 0 ? medium_list[int_id] : nullptr;

        its.geoFrame = its.shFrame = Frame(pRec.n);
        its.pdf = pdf;
        its.J = 1.0f;
    } else {
        // Shape-less emitter: the emitter samples its own position.
        its.ptr_med_ext = nullptr;
        its.ptr_emitter->samplePosition(rnd, pRec);
    }

    its.p = pRec.p;
    its.wi = Vector::Zero();
    its.t = 0.0f;
    if ( _pdf ) *_pdf = pdf;
    return its.ptr_emitter->getIntensity() / pdf;
}


/**
 * Sample a point on an emitter of the requested category.
 *
 * rnd_light[0] selects the emitter from the category's distribution
 * (area_light_distrb, point_light_distrb, or light_distrb for any other type)
 * via sampleReuse(). sampleReuse() rewrites that component in the working copy
 * `rnd`, and that rescaled copy is what the position sampler must receive.
 * Passing the original rnd_light would reuse the same number for both the
 * light choice and the position, correlating them. This matches the other
 * sampleEmitterPosition overload.
 *
 * @param rnd_light  2D random sample.
 * @param its        Output record describing the sampled emitter point.
 * @param type       Emitter category to sample from.
 * @return The emitter's getIntensity() divided by the selection pdf (further
 *         divided by the shape area for shape-backed emitters).
 */
Spectrum Scene::sampleEmitterPosition(
    const Vector2 &rnd_light, Intersection &its, EEmitter type) const
{
    Vector2 rnd(rnd_light);
    Float pdf;
    int light_id = -1;
    int shape_id = -1;
    switch ( type ) {
    case EAreaLight:
        light_id = area_light_distrb.sampleReuse(rnd[0], pdf);
        shape_id = area_emitter_list[light_id]->getShapeID();
        its.ptr_emitter = area_emitter_list[light_id];
        break;
    case EPoint:
        light_id = point_light_distrb.sampleReuse(rnd[0], pdf);
        shape_id = point_emitter_list[light_id]->getShapeID();
        its.ptr_emitter = point_emitter_list[light_id];
        break;
    default:
        light_id = light_distrb.sampleReuse(rnd[0], pdf);
        shape_id = emitter_list[light_id]->getShapeID();
        its.ptr_emitter = emitter_list[light_id];
        break;
    }
    if ( shape_id < 0 ) {
        // Shape-less emitters are not expected on this path.
        assert(false);
        its.indices[0] = shape_id;
        its.ptr_med_ext = nullptr;
        PositionSamplingRecord pRec;
        // Use the rescaled sample, not the raw rnd_light.
        its.ptr_emitter->samplePosition(rnd, pRec);
        its.p = pRec.p;
        its.wi = Vector(Vector::Zero());
        its.t = 0.0f;
        return its.ptr_emitter->getIntensity() / pdf;
    }
    else {
        PositionSamplingRecord pRec;
        its.indices[0] = shape_id;
        // Use the rescaled sample, not the raw rnd_light.
        its.indices[1] = shape_list[shape_id]->samplePosition(rnd, pRec);
        its.barycentric = pRec.uv;
        pdf /= shape_list[shape_id]->getArea();

        its.ptr_shape = shape_list[its.ptr_emitter->getShapeID()];
        assert(its.ptr_shape != nullptr);

        int med_id = its.ptr_shape->med_ext_id;
        its.ptr_med_ext = med_id >= 0 ? medium_list[med_id] : nullptr;
        med_id = its.ptr_shape->med_int_id;
        its.ptr_med_int = med_id >= 0 ? medium_list[med_id] : nullptr;

        its.p = pRec.p;
        its.wi = Vector::Zero();
        its.t = 0.0f;
        its.geoFrame = its.shFrame = Frame(pRec.n);
        return its.ptr_emitter->getIntensity() / pdf;
    }
}

/**
 * Connect surface point `its.p` to the camera and attenuate the camera
 * weights by the transmittance along that connection.
 *
 * The medium starts as getMedium(its.medium_id). It is replaced by the target
 * medium in direction `dir` when the point lies on a shape that is a medium
 * transition.
 *
 * @param its        Surface point to connect.
 * @param sampler    Sampler used for transmittance estimation.
 * @param pixel_uvs  Output pixel coordinates from camera.sampleDirect.
 * @param dir        Output direction from camera.sampleDirect.
 * @return Attenuated weights; exactly zero when the camera weights are zero
 *         within Epsilon.
 */
Array4 Scene::sampleAttenuatedSensorDirect(
    const Intersection &its, RndSampler *sampler, Matrix2x4 &pixel_uvs, Vector &dir) const
{
    Array4 weights(0.);
    camera.sampleDirect(its.p, pixel_uvs, weights, dir);
    if ( weights.isZero(Epsilon) ) {
        // Snap negligible weights to exact zero.
        weights.setZero();
        return weights;
    }

    const Medium *medium = getMedium(its.medium_id);
    if (its.ptr_shape && its.isMediumTransition())
        medium = its.getTargetMedium(dir);

    const Float cam_dist = (its.p - camera.cpos).norm();
    weights *= evalTransmittance(Ray(its.p, dir), true, medium, cam_dist, sampler);
    return weights;
}

/**
 * Connect a medium point `p` to the camera and attenuate the camera weights by
 * the transmittance through `ptr_med` along that connection.
 *
 * @param p          Point to connect (not on a surface).
 * @param ptr_med    Medium containing `p` (may be nullptr).
 * @param sampler    Sampler used for transmittance estimation.
 * @param pixel_uvs  Output pixel coordinates from camera.sampleDirect.
 * @param dir        Output direction from camera.sampleDirect.
 * @return The attenuated weights (left as sampled when zero within Epsilon).
 */
Array4 Scene::sampleAttenuatedSensorDirect(
    const Vector &p, const Medium *ptr_med, RndSampler *sampler,
    Matrix2x4 &pixel_uvs, Vector &dir) const
{
    Array4 weights(0.0);
    camera.sampleDirect(p, pixel_uvs, weights, dir);
    if ( weights.isZero(Epsilon) )
        return weights;

    const Float cam_dist = (p - camera.cpos).norm();
    weights *= evalTransmittance(Ray(p, dir), false, ptr_med, cam_dist, sampler);
    return weights;
}

std::tuple<Float, int, Vector> Scene::sampleAttenuatedSensorDirect(const Intersection &its, RndSampler *sampler) const
{
    int pixel_idx;
    Vector dir;
    Float value = camera.sampleDirect(its.p, sampler->next1D(), pixel_idx, dir);
    if ( value > Epsilon ) {
        const Medium *ptr_medium = its.getTargetMedium(dir);
        Float dist = (its.p - camera.cpos).norm();
        return {value * evalTransmittance(Ray(its.p, dir), true, ptr_medium, dist, sampler), pixel_idx, dir};
    }
    return {0., -1, Vector::Zero()};
}

std::tuple<Float, int, Vector> Scene::sampleAttenuatedSensorDirect(const Vector &p, const Medium *ptr_med, RndSampler *sampler) const
{
    int pixel_idx;
    Vector dir;    
    Float value = camera.sampleDirect(p, sampler->next1D(), pixel_idx, dir);
    if ( value > Epsilon ) {
        Float dist = (p - camera.cpos).norm();
        return {value * evalTransmittance(Ray(p, dir), false, ptr_med, dist, sampler), pixel_idx, dir};
    }
    return {0., -1, Vector::Zero()};
}

/**
 * Fill `cRec` with a camera connection from surface point `its.p`.
 *
 * @return Transmittance along the connection through the target medium, or 0
 *         when camera sampling fails or cRec.baseVal is not above Epsilon.
 */
Float Scene::sampleAttenuatedSensorDirect(
    const Intersection &its, RndSampler *sampler, CameraDirectSamplingRecord &cRec) const
{
    if ( !camera.sampleDirect(its.p, cRec) || !(cRec.baseVal > Epsilon) )
        return 0.;

    const Medium *medium = its.getTargetMedium(cRec.dir);
    const Float cam_dist = (its.p - camera.cpos).norm();
    return evalTransmittance(Ray(its.p, cRec.dir), true, medium, cam_dist, sampler);
}

/**
 * Fill `cRec` with a camera connection from medium point `p`.
 *
 * @return Transmittance along the connection through `ptr_med`, or 0 when
 *         camera sampling fails or cRec.baseVal is not above Epsilon.
 */
Float Scene::sampleAttenuatedSensorDirect(
    const Vector &p, const Medium *ptr_med, RndSampler *sampler, CameraDirectSamplingRecord &cRec) const
{
    if ( !camera.sampleDirect(p, cRec) || !(cRec.baseVal > Epsilon) )
        return 0.;

    const Float cam_dist = (p - camera.cpos).norm();
    return evalTransmittance(Ray(p, cRec.dir), false, ptr_med, cam_dist, sampler);
}

/**
 * Sample a point on the scene's surfaces.
 *
 * _rnd2[0] picks a shape from shape_distrb via sampleReuse(), which rewrites
 * that component of the local copy. The copy is then reused to sample a point
 * on the chosen shape.
 *
 * @param _rnd2  2D random sample.
 * @param pRec   Output position record filled by the shape.
 * @return (shape index, value returned by the shape's samplePosition).
 */
Vector2i Scene::samplePosition(
    const Vector2 &_rnd2, PositionSamplingRecord &pRec) const
{
    Vector2 sample(_rnd2);
    Vector2i picked;
    picked[0] = static_cast<int>(shape_distrb.sampleReuse(sample[0]));
    picked[1] = shape_list[picked[0]]->samplePosition(sample, pRec);
    return picked;
}


std::tuple<Vector, Vector, Float> Scene::getPoint(const Intersection &its) const
{
    assert( its.indices[0] >= 0 );
    assert( its.indices[0] < static_cast<int>(shape_list.size()) );
    const Shape &shape = *shape_list[its.indices[0]];
    return shape.getPoint(its.indices[1], its.barycentric);
}


/**
 * Copy the scene-level references and ids of `its` into `itsAD`, then let the
 * owning shape fill in the rest of `itsAD` from `its`.
 */
void Scene::getPoint(const Intersection &its, Intersection &itsAD) const
{
    assert( its.indices[0] >= 0 );
    assert( its.indices[0] < static_cast<int>(shape_list.size()) );

    // Identification data.
    itsAD.indices = its.indices;
    itsAD.uv = its.uv;

    // Scene object references.
    itsAD.ptr_shape = its.ptr_shape;
    itsAD.ptr_bsdf = its.ptr_bsdf;
    itsAD.ptr_emitter = its.ptr_emitter;
    itsAD.ptr_med_int = its.ptr_med_int;
    itsAD.ptr_med_ext = its.ptr_med_ext;

    shape_list[its.indices[0]]->getPoint(its, itsAD);
}


/**
 * Sample a point on one of the edges listed in `edge_ind`.
 *
 * `_rnd` picks an edge from `edge_dist` via sampleReuse(), which rescales the
 * sample. The rescaled value is then used as the parameter t along the edge
 * (ref = v0 + (v1 - v0) * t), and the pdf is the edge selection pdf divided by
 * edge.length.
 *
 * eRec.med_id is chosen as follows:
 *  - edges with two adjacent faces (f1 >= 0) on a shape with an interior or
 *    exterior medium: the medium on the side picked by the orientation test;
 *  - edges with one adjacent face (f1 < 0): the interior medium;
 *  - otherwise: -1.
 *
 * On a malformed edge (f0 < 0, or a vertex index outside
 * [0, vertices.size())) an error is printed and eRec.shape_id is set to -1.
 *
 * @param _rnd      Uniform random number.
 * @param edge_dist Distribution over entries of `edge_ind`.
 * @param edge_ind  (shape id, edge id) pairs.
 * @param eRec      Output edge sample.
 */
void Scene::sampleEdgePoint(
    const Float &_rnd, const DiscreteDistribution &edge_dist, const std::vector<Vector2i> &edge_ind,
    EdgeSamplingRecord &eRec) const
{
    Float rnd(_rnd);
    Float pdf = 1.f;
    assert(edge_dist.getSum() > Epsilon);
    int i = edge_dist.sampleReuse(rnd, pdf);
    const Vector2i &ind = edge_ind.at(i);
    int shape_id = ind[0],
        edge_id = ind[1];
    const Shape *shape = shape_list[shape_id];
    const Edge &edge = shape->edges[edge_id];
    if (edge.f0 < 0) {
        std::cerr << "Error: edge.f0 < 0" << std::endl;
        eRec.shape_id = -1;
        return;
    }
    // Valid vertex indices are [0, vertices.size()); index == size() is
    // already out of bounds.
    const int num_vertices = static_cast<int>(shape->vertices.size());
    if (edge.v0 < 0 || edge.v1 < 0 ||
        edge.v0 >= num_vertices ||
        edge.v1 >= num_vertices)
    {
        std::cerr << "Error: edge.v0 < 0 || edge.v1 < 0 || edge.v0 >= shape->vertices.size() || edge.v1 >= shape->vertices.size()" << std::endl;
        eRec.shape_id = -1;
        return;
    }
    pdf /= edge.length;
    const Vector &v0 = shape->vertices[edge.v0],
                 &v1 = shape->vertices[edge.v1];

    int med_id = -1;
    // FIXME: need to handle planes
    if (edge.f1 >= 0) {
        // The side test only matters if the shape bounds some medium on
        // either side.
        if (shape->med_int_id >= 0 ||
            shape->med_ext_id >= 0)
        {
            const Vector &v2 = shape->getVertex(edge.v2);
            Vector v = v2 - v0;
            Vector n0 = shape->getGeoNormal(edge.f0),
                   n1 = shape->getGeoNormal(edge.f1);
            Vector n = n0 + n1;
            if (v.dot(n) > 0)
                med_id = shape->med_int_id;
            else
                med_id = shape->med_ext_id;
        }
    }
    else
        med_id = shape->med_int_id;

    eRec.pdf = detach(pdf);
    eRec.edge_id = edge_id;
    eRec.shape_id = shape_id;
    eRec.t = detach(rnd);
    eRec.ref = v0 + (v1 - v0) * eRec.t;
    eRec.med_id = med_id;
}

/* ========================= Normal precomputation ========================= */
#ifdef NORMAL_PREPROCESS
/* normal related */
// Run the per-shape d_precompute_normal on each (shape, d_shape) pair;
// both scenes must have the same number of shapes.
void d_precompute_normal(const Scene &scene, Scene &d_scene)
{
    assert(scene.shape_list.size() == d_scene.shape_list.size());
    const size_t shape_count = scene.shape_list.size();
    for (size_t k = 0; k < shape_count; ++k)
        d_precompute_normal(*scene.shape_list[k], *d_scene.shape_list[k]);
}
#endif
