#include "algorithm1.h"
#include <render/scene.h>

NAMESPACE_BEGIN(algorithm1_bdpt)

// Evaluates the accumulated throughput (`value`) of vertex `index` of `path`.
//
// Vertices 0..index-1 must already have been evaluated. For index > 0 the
// result is multiplied by path[index-1].value, so calling this for
// index = 0, 1, 2, ... in order builds the prefix throughput of the path.
//
//  - index == 0, light path : emitter intensity / pdf * J
//  - index == 0, camera path: camera.eval(pixel, p, n) / pdf * J
//  - index  > 0             : f * G * J / pdf * path[index-1].value, where
//                             G = |n_index . dir| / dist^2 for the segment
//                             path[index-1] -> path[index]. f is either the
//                             emitter's directional term, when the previous
//                             vertex is the light-path origin (index == 1 on
//                             a light path), or the previous vertex's BSDF
//                             (ERadiance mode on camera paths,
//                             EImportanceWithCorrection on light paths).
//
// Camera paths also carry an antithetic sub-path in path.antithetic_vtx.
// For index <= 2, its throughput is evaluated as well, but only when
// antithetic_vtx[index].type != EVInvalid.
//  - At index 0 and 1 the result is stored in antithetic_vtx[index].value.
//  - At index 2 the antithetic segment ends at the shared vertex path[2],
//    and its contribution is added into path[2].value.
//
// This is the primal function differentiated by Enzyme in d_evalVertex.
void __evalVertex(const Scene& scene, LightPath& path, int index)
{
    const Camera& camera = scene.camera;
    // Geometry term and unit direction of the segment path[index-1] -> path[index].
    // The antithetic branch below reuses both unchanged when index == 2.
    Float G = 1.0f;
    Vector dir;
    if (index == 0) {
        if (!path.isCameraPath) {
            // Light-path origin lies on an emitter shape; its light_id indexes scene.emitter_list.
            int emitter_id = path[0].ptr_shape->light_id;
            const Emitter* ptr_emitter = scene.emitter_list[emitter_id];
            path[0].value = ptr_emitter->getIntensity() / path[0].pdf * path[0].J;  // may need to change this if not using an area light
        } else {
            // Camera-path origin: evaluated against pixel path.pixelIdx via camera.eval.
            path[0].value = camera.eval(path.pixelIdx[0], path.pixelIdx[1], path[0].p, path[0].geoFrame.n) / path[0].pdf * path[0].J;
        }
    } else {
        // Unit direction from the previous vertex, and |cos| / dist^2 at this vertex.
        dir = path[index].p - path[index-1].p;
        Float dist = dir.norm();
        dir /= dist;
        G = std::abs(path[index].geoFrame.n.dot(dir)) / (dist * dist);
        if ( index == 1 && !path.isCameraPath) {
            // Previous vertex is the light-path origin: use the emitter's directional term.
            int emitter_id = path[index-1].ptr_shape->light_id;
            const Emitter* ptr_emitter = scene.emitter_list[emitter_id];
            path[index].value = ptr_emitter->evalDirection(path[index-1].geoFrame.n, dir);
        } else {
            // Otherwise use the previous vertex's BSDF, with the mode chosen by path type.
            int bsdf_id = path[index-1].ptr_shape->bsdf_id;
            const BSDF* ptr_bsdf = scene.bsdf_list[bsdf_id];
            path[index].value = ptr_bsdf->eval(path[index-1], path[index-1].toLocal(dir),
                                               path.isCameraPath ? ERadiance : EImportanceWithCorrection);
        }
        // Convert to this vertex's sampling measure, then chain the prefix throughput.
        path[index].value *= path[index].J * G / path[index].pdf;
        path[index].value *= path[index-1].value;
    }

    // Antithetic sub-path of camera paths (vertices 0..2 only).
    if (path.isCameraPath && index <= 2) {
        if (path.antithetic_vtx[index].type == EVInvalid) return;
        // At index 2 the antithetic segment ends at the primary vertex path[2].
        Intersection& its = (index == 2) ? path[index]
                                         : path.antithetic_vtx[index];
        if (index == 0) {
            its.value = camera.eval(path.pixelIdx[0], path.pixelIdx[1], its.p, its.geoFrame.n) / its.pdf * its.J;
        } else {
            Intersection& its_prev = path.antithetic_vtx[index-1];
            // index 1: the segment antithetic_vtx[0] -> antithetic_vtx[1] gets its own dir/G.
            // index 2: dir/G of the primary segment path[1] -> path[2] are reused as-is.
            // NOTE(review): that reuse is exact only if antithetic_vtx[1].p == path[1].p; confirm.
            if (index == 1) {
                dir = its.p - its_prev.p;
                Float dist = dir.norm();
                dir /= dist;
                G = std::abs(its.geoFrame.n.dot(dir)) / (dist * dist);
            }
            int bsdf_id = its_prev.ptr_shape->bsdf_id;
            const BSDF* ptr_bsdf = scene.bsdf_list[bsdf_id];
            // No mode argument is passed, so the BSDF's default eval mode is used.
            Spectrum bsdf_val = ptr_bsdf->eval(its_prev, its_prev.toLocal(dir));
            // NOTE(review): this divides by the primary vertex's pdf (path[index].pdf), not its.pdf.
            // Presumably intentional for antithetic sampling; confirm.
            Spectrum contrb_anti = its_prev.value * bsdf_val * G * its.J / path[index].pdf;
            if (index == 2)
                path[index].value += contrb_anti;
            else
                its.value = contrb_anti;
        }
    }
}

// Reverse-mode derivative of __evalVertex for a single vertex.
//
// Runs __evalVertex through Enzyme:
//  - the scene's shadow is the object returned by
//    sceneAD.gm.get(omp_get_thread_num());
//  - the path's shadow is pathAD.der;
//  - `index` is passed as an Enzyme constant.
void d_evalVertex(SceneAD &sceneAD, LightPathAD& pathAD, int index)
{
    const int threadId = omp_get_thread_num();
    auto &sceneGrad = sceneAD.gm.get(threadId);
    __enzyme_autodiff((void *)__evalVertex,
                      enzyme_dup,   &sceneAD.val, &sceneGrad,
                      enzyme_dup,   &pathAD.val,  &pathAD.der,
                      enzyme_const, index);
}

// Evaluates one BDPT connection strategy: camera-path prefix [0..s] joined to
// light-path prefix [0..t]. The result, multiplied by `w`, is written to
// `value`, which is overwritten rather than accumulated into. Per-vertex
// throughputs (`value`) are expected to come from __evalVertex.
//
//  - s == -1: light_path[t] is connected directly to the camera position
//             (camera.cpos). value = light_path[t].value * camera.eval(...)
//             * BSDF(light_path[t] -> camera).
//  - t == -1: camera_path[s] lies on an emitter. value = emitter eval toward
//             the previous camera vertex (camera.cpos when s == 0)
//             * camera_path[s].value.
//  - otherwise: explicit connection camera_path[s] <-> light_path[t].
//             The light end uses the emitter's directional term when t == 0
//             and its BSDF otherwise; the camera end uses its BSDF. There is
//             a 1/dist^2 factor. No visibility test is performed in this
//             function.
//
// When antithetic_success is true, an antithetic counterpart is added:
//  - s == -1: light_path.antithetic_vtx[0] replaces light_path[t]. This
//             reads light_path[t-1], so t >= 1 is assumed. The term is
//             scaled by `hack`.
//  - otherwise, only for s <= 1: camera_path.antithetic_vtx[s] replaces
//             camera_path[s]. For s >= 2, __evalVertex has already added the
//             antithetic contribution into camera_path[2].value, which is
//             chained into later vertices.
//
// w:    scalar factor applied to the final value (presumably the MIS weight;
//       confirm against the caller).
// hack: constant factor for the s == -1 antithetic term only. d_evalPath
//       passes camera.geometric(primary) / camera.geometric(antithetic) as an
//       Enzyme constant, so no derivative flows through this factor.
void __evalPath(const Scene &scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                Float w, bool antithetic_success, Spectrum& value, Float hack)
{
    if ( s == -1 ) {
        // Light vertex connected directly to the camera position.
        const Camera& camera = scene.camera;
        const Intersection& v_lgt = light_path[t];
        Vector dir = (camera.cpos - v_lgt.p).normalized();
        int bsdf_id = v_lgt.ptr_shape->bsdf_id;
        Spectrum bsdf_val = scene.bsdf_list[bsdf_id]->eval(v_lgt, v_lgt.toLocal(dir), EBSDFMode::EImportanceWithCorrection);
        Float camera_val = camera.eval(light_path.pixelIdx[0], light_path.pixelIdx[1], v_lgt.p);
        value = v_lgt.value * camera_val * bsdf_val;
        if (antithetic_success) {
            // The antithetic light vertex replaces light_path[t].
            // Re-evaluate the last segment light_path[t-1] -> v_lgt2 from the prefix throughput.
            const Intersection& v_lgt2 = light_path.antithetic_vtx[0];
            Spectrum val_anti = light_path[t-1].value;
            Float G = geometric(light_path[t-1].p, v_lgt2.p, v_lgt2.geoFrame.n);
            Vector dir = (v_lgt2.p - light_path[t-1].p).normalized();
            if (t == 1) {
                // light_path[0] is the light-path origin: use the emitter's directional term.
                int emitter_id = light_path[t-1].ptr_shape->light_id;
                val_anti *=  scene.emitter_list[emitter_id]->evalDirection(light_path[t-1].geoFrame.n, dir);
            } else {
                int bsdf_id = light_path[t-1].ptr_shape->bsdf_id;
                val_anti *= scene.bsdf_list[bsdf_id]->eval(light_path[t-1], light_path[t-1].toLocal(dir), EBSDFMode::EImportanceWithCorrection);
            }
            val_anti *= G * v_lgt2.J / light_path[t].pdf;      // NOTE(review): flagged "check here" by the author; divides by the primary vertex's pdf

            // Camera response and BSDF toward the camera, evaluated at the antithetic vertex.
            Float camera_val2 = camera.eval(light_path.pixelIdx[0], light_path.pixelIdx[1], v_lgt2.p);
            dir = (camera.cpos - v_lgt2.p).normalized();
            bsdf_id = v_lgt2.ptr_shape->bsdf_id;
            Spectrum bsdf_val2 = scene.bsdf_list[bsdf_id]->eval(v_lgt2, v_lgt2.toLocal(dir), EBSDFMode::EImportanceWithCorrection);
            val_anti *= camera_val2 * bsdf_val2;
            // This ratio is now supplied through `hack` (computed in d_evalPath) instead of here:
            // val_anti *= camera.geometric(v_lgt.p, v_lgt.geoFrame.n) / camera.geometric(v_lgt2.p, v_lgt2.geoFrame.n);
            val_anti *= hack;
            value += val_anti;
        }
    } else if ( t == -1 ) {
        // Camera vertex lying on an emitter.
        const Camera& camera = scene.camera;
        const Intersection& v_cam = camera_path[s];
        // Emission is evaluated toward the previous camera vertex (the camera position when s == 0).
        const Vector& pre_p = (s == 0) ? camera.cpos : camera_path[s-1].p;
        Vector dir = (pre_p - v_cam.p).normalized();
        int emitter_id = v_cam.ptr_shape->light_id;
        value = scene.emitter_list[emitter_id]->eval(v_cam.geoFrame.n, dir) * v_cam.value;
        // For s >= 2 the antithetic contribution is already part of v_cam.value.
        if (antithetic_success  && s <= 1) {
            const Intersection& v_cam2 = camera_path.antithetic_vtx[s];
            const Vector& pre_p2 = (s == 0) ? camera.cpos : camera_path.antithetic_vtx[s-1].p;
            Vector dir = (pre_p2 - v_cam2.p).normalized();
            // NOTE(review): assumes v_cam2's shape is also an emitter (valid light_id); confirm the caller guarantees it.
            int emitter_id = v_cam2.ptr_shape->light_id;
            value += scene.emitter_list[emitter_id]->eval(v_cam2.geoFrame.n, dir) * v_cam2.value;
        }
    } else {
        // Explicit connection camera_path[s] <-> light_path[t].
        const Intersection& v_lgt = light_path[t];
        const Intersection& v_cam = camera_path[s];
        // Unit direction camera vertex -> light vertex, and squared distance.
        Vector dir = v_lgt.p - v_cam.p;
        Float dist2 = dir.squaredNorm();
        dir /= std::sqrt(dist2);
        Spectrum seg_lgt(0.f);
        if ( t == 0 ) {
            // Light end is the light-path origin: emitter's directional term toward the camera vertex.
            int emitter_id = v_lgt.ptr_shape->light_id;
            seg_lgt = scene.emitter_list[emitter_id]->evalDirection(v_lgt.geoFrame.n, -dir);
        } else {
            int bsdf_id = v_lgt.ptr_shape->bsdf_id;
            seg_lgt = scene.bsdf_list[bsdf_id]->eval(v_lgt, v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
        }
        // Inverse-square falloff of the connection segment.
        seg_lgt /= dist2;
        int bsdf_id = v_cam.ptr_shape->bsdf_id;
        // Camera-end BSDF uses its default eval mode (no mode argument).
        value = seg_lgt * scene.bsdf_list[bsdf_id]->eval(v_cam, v_cam.toLocal(dir)) * v_lgt.value * v_cam.value;

        if (s <= 1 && antithetic_success) {
            const Intersection& v_cam2 = camera_path.antithetic_vtx[s];
            Spectrum val_anti = v_cam2.value * v_lgt.value;
            if (s == 0) {
                // s == 0: re-evaluate the whole connection from the antithetic vertex v_cam2.
                Vector dir = v_lgt.p - v_cam2.p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                if ( t == 0 ) {
                    int emitter_id = v_lgt.ptr_shape->light_id;
                    val_anti *= scene.emitter_list[emitter_id]->evalDirection(v_lgt.geoFrame.n, -dir);
                } else {
                    int bsdf_id = v_lgt.ptr_shape->bsdf_id;
                    val_anti *= scene.bsdf_list[bsdf_id]->eval(v_lgt, v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
                }
                int bsdf_id = v_cam2.ptr_shape->bsdf_id;
                val_anti *= scene.bsdf_list[bsdf_id]->eval(v_cam2, v_cam2.toLocal(dir)) / dist2;
            } else if (s == 1) {
                // s == 1: reuse the primary connection direction and light-side term.
                // Only the camera-side BSDF is evaluated at v_cam2.
                // NOTE(review): exact only if antithetic_vtx[1].p == camera_path[1].p; confirm.
                int bsdf_id = v_cam2.ptr_shape->bsdf_id;
                val_anti *= seg_lgt * scene.bsdf_list[bsdf_id]->eval(v_cam2, v_cam2.toLocal(dir));
            }
            value += val_anti;
        }
    }
    value *= w;
}

// Reverse-mode derivative of __evalPath for one (s, t) strategy.
//
// Derivative targets:
//  - scene: the object returned by sceneAD.gm.get(omp_get_thread_num());
//  - camera path: cameraPathAD.der;
//  - light path: lightPathAD.der.
// `d_value` is the adjoint seeded on the strategy's output. The return value
// is the primal contribution written by __evalPath.
//
// For s == -1 with a successful antithetic sample, the camera geometric ratio
// (primary light vertex over antithetic light vertex) is computed here. It is
// handed to __evalPath as an Enzyme constant instead of being evaluated
// inside the differentiated function. Otherwise the factor is left at 0 and
// __evalPath does not read it.
Spectrum d_evalPath(SceneAD &sceneAD, LightPathAD& cameraPathAD, int s, LightPathAD& lightPathAD, int t,
                    Float w, bool antithetic_success, Spectrum d_value)
{
    auto &sceneGrad = sceneAD.gm.get(omp_get_thread_num());
    Spectrum primal = Spectrum::Ones();

    // Workaround for a suspected Enzyme detach bug: the geometric ratio is
    // evaluated outside the differentiated function and passed in as a constant.
    Float geoRatio = 0.0f;
    const bool needsGeoRatio = antithetic_success && (s == -1);
    if (needsGeoRatio) {
        const auto &cam = sceneAD.val.camera;
        const auto &vPrimary = lightPathAD.val[t];
        const auto &vAnti = lightPathAD.val.antithetic_vtx[0];
        geoRatio = cam.geometric(vPrimary.p, vPrimary.geoFrame.n)
                 / cam.geometric(vAnti.p, vAnti.geoFrame.n);
    }

    __enzyme_autodiff((void *)__evalPath,
                      enzyme_dup,   &sceneAD.val,      &sceneGrad,
                      enzyme_dup,   &cameraPathAD.val, &cameraPathAD.der,
                      enzyme_const, s,
                      enzyme_dup,   &lightPathAD.val,  &lightPathAD.der,
                      enzyme_const, t,
                      enzyme_const, w,
                      enzyme_const, antithetic_success,
                      enzyme_dup,   &primal,           &d_value,
                      enzyme_const, geoRatio);
    return primal;
}

// Rebuilds a path vertex's geometry from its triangle parameterization
// (shape_id, triangle_id, barycentric) using the current scene mesh. This is
// the primal function that d_getPoint runs through Enzyme. That propagates
// derivatives of the fields written here back to the scene.
//
// Target vertex:
//   antithetic == false -> path[index]
//   antithetic == true  -> path.antithetic_vtx[index] on camera paths,
//                          path.antithetic_vtx[0] on light paths
//
// Writes p, geoFrame, shFrame and J. It also writes wi, except for light-path
// vertex 0, which has no predecessor.
void __getPoint(const Scene& scene, LightPath& path, int index, bool antithetic) {
    Intersection &v = antithetic ? (path.isCameraPath ? path.antithetic_vtx[index] : path.antithetic_vtx[0])
                                 : path[index];
    const Shape *shape = scene.shape_list[v.shape_id];
    const Vector3i &ind = shape->indices[v.triangle_id];
    const Vector &v0 = shape->getVertex(ind[0]),
                &v1 = shape->getVertex(ind[1]),
                &v2 = shape->getVertex(ind[2]);
    // Barycentric interpolation: p = (1 - u - v) * v0 + u * v1 + v * v2.
    v.p = (1. - v.barycentric.x() - v.barycentric.y()) * v0 +
        v.barycentric.x() * v1 +
        v.barycentric.y() * v2;
    Vector geo_n = shape->getFaceNormal(v.triangle_id);
    Vector sh_n = shape->getShadingNormal(v.triangle_id, v.barycentric);
    v.geoFrame = Frame(geo_n);
    v.shFrame = Frame(sh_n);
    // J = area / detach(area). Presumably this is numerically 1 while still carrying
    // the derivative of the triangle area (assuming detach() stops derivatives); confirm.
    v.J = shape->getArea(v.triangle_id);
    v.J /= detach(v.J);

    // wi: local-frame direction toward the previous vertex.
    if (path.isCameraPath) {
        // Camera paths: the camera position for vertex 0, otherwise the previous
        // vertex of the same (primary or antithetic) sub-path.
        const Vector& pre_p = (index == 0) ? scene.camera.cpos 
                                           : (antithetic ? path.antithetic_vtx[index-1].p : path[index-1].p);
        Vector dir = (pre_p - v.p).normalized();
        v.wi = v.toLocal(dir);
    } else {
        // Light paths: vertex 0 keeps its wi. Others, including the antithetic
        // replacement vertex, point toward the primary predecessor path[index-1].
        if (index > 0) {
            Vector dir = (path[index-1].p - v.p).normalized();
            v.wi = v.toLocal(dir);
        }
    }
}

// Reverse-mode derivative of __getPoint for one path vertex.
//
// With `antithetic` set, the targeted vertex is antithetic_vtx[index] on
// camera paths and antithetic_vtx[0] on light paths. If that vertex is marked
// EVInvalid there is nothing to differentiate, and the call returns
// immediately.
//
// Derivative targets:
//  - scene: the object returned by sceneAD.gm.get(omp_get_thread_num());
//  - path: pathAD.der.
// `index` and `antithetic` are passed as Enzyme constants.
void d_getPoint(SceneAD &sceneAD, LightPathAD& pathAD, int index, bool antithetic) {
    if (antithetic) {
        const auto &anti = pathAD.val.antithetic_vtx;
        const int slot = pathAD.val.isCameraPath ? index : 0;
        if (anti[slot].type == EVInvalid) {
            return;
        }
    }

    const int threadId = omp_get_thread_num();
    auto &sceneGrad = sceneAD.gm.get(threadId);
    __enzyme_autodiff((void *)__getPoint,
                      enzyme_dup,   &sceneAD.val, &sceneGrad,
                      enzyme_dup,   &pathAD.val,  &pathAD.der,
                      enzyme_const, index,
                      enzyme_const, antithetic);
}


NAMESPACE_END(algorithm1_bdpt)