#include "bdpt2.h"
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"

// #define BDPT_PARTICLE
// #define BDPT_PT

namespace
{
    struct BDPTRecord
    {
        /// Per-sample state shared by the bidirectional trace of one pixel sample.
        /// @param rndSampler     random sampler used to draw all path decisions (not owned)
        /// @param maxBounceCount maximum number of bounces per sub-path
        /// @param pixel          index of the pixel being traced
        /// @param outImage       output image; held by reference, so it must outlive this record
        BDPTRecord(RndSampler* rndSampler, int maxBounceCount, const Array2i& pixel, std::vector<Spectrum>& outImage)
            : sampler(rndSampler)
            , max_bounces(maxBounceCount)
            , pixelIdx(pixel)
            , image(outImage)
        {
        }

        mutable RndSampler *sampler;      // non-owning
        int max_bounces;
        Array2i pixelIdx;
        std::vector<Spectrum> &image;     // non-owning reference to the output image

        // Sub-path buffers are mutable so the const accessors below can hand out writable references.
        mutable algorithm1_bdpt::LightPath light_path;
        mutable algorithm1_bdpt::LightPath camera_path;
        mutable algorithm1_bdpt::LightPath camera_path2;   // returned by getCameraPath(true)

        /// Writable access to the light sub-path buffer.
        algorithm1_bdpt::LightPath &getLightPath() const
        {
            return light_path;
        }

        /// Writable access to a camera sub-path buffer: camera_path2 when `antithetic`, else camera_path.
        algorithm1_bdpt::LightPath &getCameraPath(bool antithetic = false) const
        {
            if (antithetic)
                return camera_path2;
            return camera_path;
        }

        /// MIS weight of strategy (s, t); defined out of line below.
        Float computeWeight(const Scene& scene, int s, int t, Float sensor_val, bool antithetic_success);
    };

    struct BDPTRecordAD : BDPTRecord
    {
        /// BDPTRecord extended with the extra output buffers used when differentiating.
        /// All image arguments are held by reference and must outlive the record.
        BDPTRecordAD(RndSampler* rnd_sampler, int maxBounces, const Array2i& pixIdx,
                     std::vector<Spectrum>& img, std::vector<Spectrum>& d_img, std::vector<Spectrum>& grad_img)
            : BDPTRecord(rnd_sampler, maxBounces, pixIdx, img)
            , d_image(d_img)
            , grad_image(grad_img)
        {
            // Every member is bound in the initializer list; reference members cannot be reseated here anyway.
        }

        std::vector<Spectrum> &d_image;      // passed to d_evalPath by bidirTrace
        std::vector<Spectrum> &grad_image;   // accumulated by bidirTrace when FORWARD is defined
    };

    /// Maps a primal type T to the type bidirTrace works with:
    /// ad == false selects the primal type, ad == true its differentiable counterpart.
    template <typename T, bool ad>
    struct type_traits;

    // Scene: a read-only primal scene, or SceneAD when differentiating.
    template <>
    struct type_traits<Scene, false> { using type = Scene const; };

    template <>
    struct type_traits<Scene, true> { using type = SceneAD; };

    template <typename T>
    auto &value(T &t);

    template <>
    auto &value(const Scene &t)
    {
        return t;
    }

    template <>
    auto &value(SceneAD &t)
    {
        return t.val;
    }

    // BDPTRecord: the plain record for primal rendering, BDPTRecordAD when differentiating.
    template <>
    struct type_traits<BDPTRecord, false> { using type = BDPTRecord; };

    template <>
    struct type_traits<BDPTRecord, true> { using type = BDPTRecordAD; };

    /// Shorthand: Type<T, ad> names type_traits<T, ad>::type.
    template <typename T, bool ad>
    using Type = typename type_traits<T, ad>::type;

    /// Squared pdf ratio (pdf0 / pdf1)^2: the per-strategy term of the power heuristic with
    /// exponent 2. Callers sum these terms (plus 1) and invert the sum to get a MIS weight.
    /// Nothing here guards against pdf1 == 0, which yields inf/NaN.
    Float misRatio(Float pdf0, Float pdf1)
    {
        const Float ratio = pdf0 / pdf1;
        return ratio * ratio;
    }

    /// Evaluates the (un-MIS-weighted) contribution of BDPT strategy (s, t). When an antithetic
    /// camera sub-path is available and s <= 1, the antithetic contribution is added as well.
    ///
    /// @param scene        scene used for emission queries and visibility tests
    /// @param camera_path  camera sub-path; camera_path[s] is the last camera vertex used
    /// @param s            index of the last camera vertex used
    /// @param light_path   light sub-path; light_path[t] is the last light vertex used
    /// @param t            index of the last light vertex, or -1 when camera_path[s] is itself
    ///                     evaluated as an emitter (no connection)
    /// @param contrb_all   [out] contribution; only assigned or added to, so the caller must zero it.
    ///                     It is left untouched when the vertex is not an emitter (t == -1) or the
    ///                     main connection is occluded.
    /// @param antithetic_success [out] whether the antithetic sub-path is considered valid. It is
    ///                     seeded from antithetic_vtx[min(2, s)] and only re-evaluated for s <= 1
    ///                     with a non-zero main contribution. NOTE(review): the seeded value can be
    ///                     stale when contrb_all stays zero; the visible callers discard zero
    ///                     contributions before reading it.
    ///
    /// For s >= 2 no antithetic term is added here: sampleLightPath already folds
    /// antithetic_vtx[2].value into path[2].value.
    void evalBDPTPath(const Scene& scene, const algorithm1_bdpt::LightPath& camera_path, int s, const algorithm1_bdpt::LightPath& light_path, int t,
                    Spectrum& contrb_all, bool& antithetic_success)
    {
        antithetic_success = camera_path.antithetic_vtx[std::min(2, s)].type != EVInvalid;
        if ( t == -1 ) {
            // Camera vertex s is treated as an emitter (no light sub-path is used).
            //assert(s > 0);
            const Intersection& v_cam = camera_path[s];
            const Vector& pre_p = (s == 0) ? scene.camera.cpos : camera_path[s-1].p;
            if (v_cam.isEmitter()) {
                Vector dir = (pre_p - v_cam.p).normalized();
                contrb_all = v_cam.Le(dir) * camera_path[s].value;
            }

            if (!contrb_all.isZero() && s <= 1 && antithetic_success) {
                // Repeat the emitter evaluation on the antithetic vertex; the predecessor is the
                // camera (s == 0) or the antithetic first hit (s == 1).
                const Intersection& v_cam2 = camera_path.antithetic_vtx[s];
                const Vector& pre_p2 = (s == 0) ? scene.camera.cpos : camera_path.antithetic_vtx[s-1].p;
                antithetic_success = false;
                if (v_cam2.isEmitter()) {
                    Vector dir = (pre_p2 - v_cam2.p).normalized();
                    Spectrum seg_anti = v_cam2.Le(dir);
                    if (!seg_anti.isZero()) {
                        antithetic_success = true;
                        contrb_all += seg_anti * v_cam2.value;
                    }
                }
            }
        } else {
            // Explicit connection between camera vertex s and light vertex t.
            const Intersection& v_lgt = light_path[t];
            const Intersection& v_cam = camera_path[s];
            if (!scene.isVisible(v_lgt.p, true, v_cam.p, true)) return;

            Vector dir = v_lgt.p - v_cam.p;
            Float dist2 = dir.squaredNorm();
            dir /= std::sqrt(dist2);
            // Light-side factor: emitter directional term for t == 0, otherwise the BSDF in importance
            // mode; divided by the squared connection distance.
            Spectrum seg_lgt = (t == 0) ? Spectrum::Ones() * v_lgt.ptr_emitter->evalDirection(v_lgt.geoFrame.n, -dir)
                                        : v_lgt.evalBSDF(v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
            seg_lgt /= dist2;
            Spectrum seg_cam = v_cam.evalBSDF(v_cam.toLocal(dir), EBSDFMode::ERadiance);
            contrb_all = seg_cam * seg_lgt * v_cam.value * v_lgt.value;
            if (!contrb_all.isZero() && s <= 1 && antithetic_success) {
                const Intersection& v_cam2 = camera_path.antithetic_vtx[s];
                antithetic_success = false;
                if (s == 0) {
                    // The antithetic first hit is at a different position: redo the whole connection.
                    if (!scene.isVisible(v_lgt.p, true, v_cam2.p, true)) return;
                    Vector dir = v_lgt.p - v_cam2.p;
                    Float dist2 = dir.squaredNorm();
                    dir /= std::sqrt(dist2);
                    seg_lgt = (t == 0) ? Spectrum::Ones() * v_lgt.ptr_emitter->evalDirection(v_lgt.geoFrame.n, -dir)
                                    : v_lgt.evalBSDF(v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
                    seg_lgt /= dist2;
                    seg_cam = v_cam2.evalBSDF(v_cam2.toLocal(dir), EBSDFMode::ERadiance);
                } else if (s == 1) {
                    // antithetic_vtx[1] is a copy of camera_path[1] (sampleLightPath assigns it from the
                    // same intersection), so visibility, `dir` and seg_lgt are reused; only the BSDF
                    // changes, because its `wi` points back to the antithetic first hit.
                    seg_cam = v_cam2.evalBSDF(v_cam2.toLocal(dir), EBSDFMode::ERadiance);
                }

                Spectrum seg_anti = seg_lgt * seg_cam;
                if ( !seg_anti.isZero() ) {
                    antithetic_success = true;
                    contrb_all += seg_anti * v_cam2.value * v_lgt.value;
                }
            }
        }
    }



    /// Extends `path`, which must already contain its first vertex path[0], by a random walk.
    /// Afterward it precomputes the per-vertex MIS helper weights `path.weight[]`.
    ///
    /// Sampling:
    ///  - For a light path (!path.isCameraPath), the first segment samples an emission direction from
    ///    path[0]'s emitter. Later segments, and all camera-path segments, sample the BSDF (radiance
    ///    mode for camera paths, importance mode for light paths).
    ///  - Segments violating the shading/geometric-normal consistency test (light leak) get zero throughput.
    ///  - The walk stops on zero throughput, a miss, G < Epsilon, or its.t < ShadowEpsilon.
    ///  - The walk length is bounded by path.vertices.capacity() - 1, so the caller's reserve() sets the
    ///    bounce limit. NOTE(review): if `vertices` is a std::vector, reserve() only guarantees
    ///    capacity >= n; confirm LightPath::reserve makes this exact.
    ///
    /// Antithetic vertices (first two bounces, only if the previous antithetic vertex is valid):
    ///  - antithetic_vtx[1]: the new path[1], re-evaluated as reached from antithetic_vtx[0] by an
    ///    explicit connection.
    ///  - antithetic_vtx[2]: the new path[2], re-evaluated as reached from antithetic_vtx[1] along the
    ///    same ray. Its value is folded into path[2].value.
    ///
    /// Weights: weight[1] = 1 and weight[i] = 1 + weight[i-1] * misRatio(pdf1, path[i-2].pdf), where
    /// pdf1 is the reverse pdf of sampling path[i-2] from path[i-1]. An extra antithetic term is added
    /// when antithetic_vtx[2] is valid.
    void sampleLightPath(const Scene& scene, RndSampler* sampler, algorithm1_bdpt::LightPath& path)
    {
        int max_bounces = path.vertices.capacity() - 1;
        Float pdf;
        Ray ray;
        Intersection its = path[0];
        for (int ibounce = 0; ibounce < max_bounces; ibounce++) {
            Spectrum throughput(1.0f);
            if (ibounce == 0 && !path.isCameraPath) {
                // First light-path segment: sample an emission direction at path[0].
                throughput *= its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
                if (its.ptr_shape != nullptr)
                    ray.dir = its.geoFrame.toWorld(ray.dir);
            } else {
                Float bsdf_eta;
                Vector wo_local, wo;
                EBSDFMode mode = path.isCameraPath ? EBSDFMode::ERadiance : EBSDFMode::EImportanceWithCorrection;
                throughput *= its.sampleBSDF(sampler->next3D(), wo_local, pdf, bsdf_eta, mode);
                wo = its.toWorld(wo_local);
                // check light leak: shading and geometric normals must agree for both wi and wo
                Vector wi = its.toWorld(its.wi);
                Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                    throughput.setZero();
                ray.dir = wo;
            }
            ray.org = its.p;
            if (throughput.isZero(Epsilon) || !scene.rayIntersect(ray, true, its)) break;
            Float G = geometric(ray.org, its.p, its.geoFrame.n);
            if (G < Epsilon || its.t < ShadowEpsilon) break;     // Hack for fixing numerical issues
            its.pdf = pdf * G;
            its.value = path[ibounce].value * throughput;
            path.append(its);
            if (ibounce <= 1) {
                Intersection& anti_its = path.antithetic_vtx[ibounce];
                if (anti_its.type == EVInvalid) continue;
                if (ibounce == 0) {
                    // Connect the antithetic first vertex explicitly to the new vertex path[1].
                    if (!scene.isVisible(anti_its.p, true, its.p, true)) continue;
                    Vector wo = its.p - anti_its.p;
                    Float dist2 = wo.squaredNorm();
                    wo /= std::sqrt(dist2);
                    Vector wo_local = anti_its.toLocal(wo);
                    Spectrum bsdf_val2 = anti_its.evalBSDF(wo_local);
                    if (bsdf_val2.isZero()) continue;
                    Float G2 = std::abs(wo.dot(its.geoFrame.n)) / dist2;
                    if ( G2 < Epsilon || wo.dot(anti_its.geoFrame.n) * wo_local.z() <= 0 ) continue;
                    path.antithetic_vtx[1] = its;
                    path.antithetic_vtx[1].value = anti_its.value * bsdf_val2 * G2 / its.pdf;
                    path.antithetic_vtx[1].pdf = anti_its.pdfBSDF(wo_local) * G2;
                    assert(path.antithetic_vtx[1].pdf != 0.f);
                    path.antithetic_vtx[1].wi = path.antithetic_vtx[1].toLocal(-wo);
                } else if (ibounce == 1) {
                    // antithetic_vtx[1] sits at path[1]'s position, so the ray to path[2] is shared.
                    Vector wi = anti_its.toWorld(anti_its.wi);
                    if (wi.dot(anti_its.geoFrame.n) * anti_its.wi.z() <= 0) continue;
                    Vector wo_local = anti_its.toLocal(ray.dir);
                    Spectrum bsdf_val2 = anti_its.evalBSDF(wo_local);
                    if (bsdf_val2.isZero()) continue;
                    path.antithetic_vtx[2] = its;
                    path.antithetic_vtx[2].value = anti_its.value * bsdf_val2 * G / its.pdf;
                    path.antithetic_vtx[2].pdf = anti_its.pdfBSDF(wo_local) * G;
                    // Check the pdf that was just written (previously this re-checked antithetic_vtx[1]).
                    assert(path.antithetic_vtx[2].pdf != 0.f);
                    path[2].value += path.antithetic_vtx[2].value;
                }
            }
        }

        // Reverse pdf of the previous vertex; always assigned before it is first read below.
        Float pdf1 = 0.f;
        bool antithetic_success = path.antithetic_vtx[2].type != EVInvalid;
        Float pdf_remain_main = 1.0f;
        Float pdf_remain_anti = 1.0f;
        Float weight_anti = 1.0f;
        if (antithetic_success)
            weight_anti = misRatio(path[0].pdf, path.antithetic_vtx[0].pdf);

        for (int ibounce = 1; ibounce < path.size(); ibounce++) {
            path.weight[ibounce] = 1.0f;
            if (ibounce > 1) {
                path.weight[ibounce] += path.weight[ibounce-1] * misRatio(pdf1, path[ibounce-2].pdf);

                if (antithetic_success) {
                    if (ibounce == 2) {
                        // Reverse pdf from antithetic_vtx[1] back to antithetic_vtx[0]; wi is swapped
                        // temporarily and restored afterward.
                        const Intersection &its0 = path.antithetic_vtx[0];
                        Intersection &its1 = path.antithetic_vtx[1];
                        Vector wo_local = its1.wi;
                        Vector wi = (path.antithetic_vtx[2].p - its1.p).normalized();
                        its1.wi = its1.toLocal(wi);
                        pdf1 = its1.pdfBSDF(wo_local) * geometric(its1.p, its0.p, its0.geoFrame.n);
                        its1.wi = wo_local;
                    }
                    weight_anti = misRatio(pdf_remain_anti, pdf_remain_main) + weight_anti * misRatio(pdf1, path[ibounce-2].pdf);
                    path.weight[ibounce] += weight_anti;
                    if (ibounce <= 3) {
                        pdf_remain_anti *= path.antithetic_vtx[ibounce-1].pdf;
                        pdf_remain_main *= path[ibounce-1].pdf;
                    }
                }

            }

            if (ibounce < path.size() - 1) {
                // Reverse pdf of sampling path[ibounce-1] from path[ibounce] given the incoming direction
                // from path[ibounce+1]; wi is swapped temporarily and restored afterward.
                const Intersection &its0 = path[ibounce-1];
                Intersection &its1 = path[ibounce];
                Vector wo_local = its1.wi;
                Vector wi = (path[ibounce+1].p - its1.p).normalized();
                its1.wi = its1.toLocal(wi);
                pdf1 = its1.pdfBSDF(wo_local) * geometric(its1.p, its0.p, its0.geoFrame.n);
                its1.wi = wo_local;
            }
        }
    }

    /// Computes the MIS weight of BDPT strategy (s, t) with the power heuristic (exponent 2, via
    /// misRatio). Returns 1 / (1 + sum of squared pdf ratios of the competing strategies), including
    /// the antithetic strategies when antithetic_success is set.
    ///
    /// @param scene              scene (camera and emitter-sampling pdfs)
    /// @param s                  index of the last camera vertex; -1 means light vertex t is
    ///                           connected directly to the camera
    /// @param t                  index of the last light vertex; -1 means camera_path[s] was
    ///                           evaluated as an emitter
    /// @param sensor_val         sensor value of the camera connection; only read when s == -1
    ///                           (non-BDPT_PT build)
    /// @param antithetic_success whether the antithetic sub-path contributed to this sample
    /// @return the MIS weight; at most 1 when all pdf ratios are finite
    ///
    /// Some vertices' `wi` is overwritten temporarily to evaluate reverse BSDF pdfs and then
    /// restored, so the stored paths are unchanged on return.
    /// Under BDPT_PT only s >= 0 with t in {-1, 0} is supported (asserted).
    Float BDPTRecord::computeWeight(const Scene& scene, int s, int t, Float sensor_val, bool antithetic_success)
    {
        const Camera& camera = scene.camera;
        Float pdf_fwd1, pdf_fwd2, pdf_bwd1, pdf_bwd2;
        Float inv_weight = 1.0f;

        // related to antithetic sampling
        Float J_primary = 1.f;
        Float weight_remain = 1.f;
        if (antithetic_success) {
            // pdf'(s, t) / pdf(s, t)
            if (s == -1) {
                const Intersection& anti_its = light_path.antithetic_vtx[0];
                inv_weight += misRatio(anti_its.pdf, light_path[t].pdf);
            } else {
                J_primary = camera_path[0].pdf / camera_path.antithetic_vtx[0].pdf;
                Float pdf_remain_main = 1.f;
                Float pdf_remain_anti = 1.f;
                if (s >= 1) {
                    pdf_remain_main *= camera_path[1].pdf;
                    pdf_remain_anti *= camera_path.antithetic_vtx[1].pdf;
                }
                if (s >= 2) {
                    pdf_remain_main *= camera_path[2].pdf;
                    pdf_remain_anti *= camera_path.antithetic_vtx[2].pdf;
                }
                weight_remain = misRatio(pdf_remain_anti, pdf_remain_main);
                inv_weight += weight_remain;
            }
        }

#ifdef BDPT_PT
        assert(s >= 0);
        if (t == -1) {
            pdf_fwd1 = camera_path[s].pdf;
            pdf_bwd1 = scene.pdfEmitterSample(camera_path[s]);
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s-1, t+1) / pdf(s, t)
            if (antithetic_success) {
                Float pdf_remain_main = 1.0f;
                Float pdf_remain_anti = 1.0f;
                if (s >= 2) {
                    pdf_remain_main *= camera_path[1].pdf;
                    pdf_remain_anti *= camera_path.antithetic_vtx[1].pdf;
                }
                if (s >= 3){
                    pdf_remain_main *= camera_path[2].pdf;
                    pdf_remain_anti *= camera_path.antithetic_vtx[2].pdf;
                }
                inv_weight += misRatio(pdf_bwd1 * pdf_remain_anti, pdf_fwd1 * pdf_remain_main);
            }
        } else {
            assert(t == 0);
            pdf_fwd1 = light_path[t].pdf;
            Vector dir = light_path[t].p - camera_path[s].p;
            Float dist2 = dir.squaredNorm();
            dir /= std::sqrt(dist2);
            pdf_bwd1 = camera_path[s].pdfBSDF(camera_path[s].toLocal(dir));
            pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s+1, t-1) / pdf(s, t)

            if (antithetic_success) {
                if ( s <= 1 ) {
                    const Intersection& anti_its = camera_path.antithetic_vtx[s];
                    Vector dir = light_path[t].p - anti_its.p;
                    Float dist2 = dir.squaredNorm();
                    dir /= std::sqrt(dist2);
                    pdf_bwd1 = anti_its.pdfBSDF(anti_its.toLocal(dir));
                    pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
                }
                inv_weight += weight_remain * misRatio(pdf_bwd1, pdf_fwd1);
            }
        }

#else
        // Strategies that shorten the camera sub-path (move vertices from camera to light side).
        if (s >= 0) {
            pdf_fwd1 = camera_path[s].pdf;
            Vector dir;
            if (t >= 0) {
                dir = camera_path[s].p - light_path[t].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = (t == 0) ? light_path[t].pdfEmitter(light_path[t].toLocal(dir))      // Here maybe we should use geometric normal
                                : light_path[t].pdfBSDF(light_path[t].toLocal(dir));
                // std::abs: an unqualified abs may bind to the C int overload and truncate the cosine to 0.
                pdf_bwd1 *= std::abs(dir.dot(camera_path[s].geoFrame.n)) / dist2;
            } else {
                pdf_bwd1 = scene.pdfEmitterSample(camera_path[s]);
            }

            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s-1, t+1) / pdf(s, t)
            if (s >= 1) {
                pdf_fwd2 = camera_path[s-1].pdf;
                Vector wo_local = camera_path[s].wi;
                if (t >= 0) {
                    camera_path[s].wi = camera_path[s].toLocal(-dir);
                    pdf_bwd2 = camera_path[s].pdfBSDF(wo_local);
                    camera_path[s].wi = wo_local;
                } else {
                    pdf_bwd2 = camera_path[s].pdfEmitter(wo_local);
                }
                pdf_bwd2 *= geometric(camera_path[s].p, camera_path[s-1].p, camera_path[s-1].geoFrame.n);
                inv_weight += camera_path.weight[s] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }
            if (antithetic_success) {
                if ( s == 0 ) {
                    Vector dir;
                    assert(t >= 0);
                    dir = camera_path.antithetic_vtx[0].p - light_path[t].p;
                    Float dist2 = dir.squaredNorm();
                    dir /= std::sqrt(dist2);
                    pdf_bwd1 = (t == 0) ? light_path[t].pdfEmitter(light_path[t].toLocal(dir))      // Here maybe we should use geometric normal
                                        : light_path[t].pdfBSDF(light_path[t].toLocal(dir));
                    // std::abs: see the note above about the int overload.
                    pdf_bwd1 *= std::abs(dir.dot(camera_path.antithetic_vtx[0].geoFrame.n)) / dist2;
                    inv_weight += misRatio(pdf_bwd1 * J_primary, pdf_fwd1);
                } else {
                    Float pdf_remain_main = 1.0f;
                    Float pdf_remain_anti = 1.0f;
                    if (s >= 2) {
                        pdf_remain_main *= camera_path[1].pdf;
                        pdf_remain_anti *= camera_path.antithetic_vtx[1].pdf;
                    }
                    if (s >= 3){
                        pdf_remain_main *= camera_path[2].pdf;
                        pdf_remain_anti *= camera_path.antithetic_vtx[2].pdf;
                    }
                    inv_weight += misRatio(pdf_bwd1 * pdf_remain_anti, pdf_fwd1 * pdf_remain_main);
                    if (s == 1) {
                        Vector wo_local = camera_path.antithetic_vtx[1].wi;
                        if (t >= 0) {
                            camera_path.antithetic_vtx[1].wi = camera_path.antithetic_vtx[1].toLocal(-dir);
                            pdf_bwd2 = camera_path.antithetic_vtx[1].pdfBSDF(wo_local);
                            camera_path.antithetic_vtx[1].wi = wo_local;
                        } else {
                            pdf_bwd2 = camera_path.antithetic_vtx[1].pdfEmitter(wo_local);
                        }
                        pdf_bwd2 *= geometric(camera_path.antithetic_vtx[1].p,
                                              camera_path.antithetic_vtx[0].p,
                                              camera_path.antithetic_vtx[0].geoFrame.n);
                        inv_weight += misRatio(pdf_bwd1 * pdf_bwd2 * J_primary, pdf_fwd1 * pdf_fwd2);
                    }
                }
            }
        }

        // Strategies that shorten the light sub-path (move vertices from light to camera side).
        if ( t >= 0 ) {
            pdf_fwd1 = light_path[t].pdf;
            if (s >= 0) {
                Vector dir = light_path[t].p - camera_path[s].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = camera_path[s].pdfBSDF(camera_path[s].toLocal(dir));
                pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            } else {
                Vector dir = (light_path[t].p - camera.cpos).normalized();
                pdf_bwd1 = sensor_val / camera.getNumPixels() * std::abs(dir.dot(light_path[t].geoFrame.n));
            }
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s+1, t-1) / pdf(s, t)

            if ( t >= 1 ) {
                pdf_fwd2 = light_path[t-1].pdf;
                Vector wo_local = light_path[t].wi;
                Vector wi = (((s >= 0) ? camera_path[s].p : camera.cpos) - light_path[t].p).normalized();
                light_path[t].wi = light_path[t].toLocal(wi);
                pdf_bwd2 = light_path[t].pdfBSDF(wo_local);
                light_path[t].wi = wo_local;
                pdf_bwd2 *= geometric(light_path[t].p, light_path[t-1].p, light_path[t-1].geoFrame.n);
                // Sum(pdf(s+t-j, j)) / pdf(s, t), where -1 <= j <= t-2
                inv_weight += light_path.weight[t] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }

            if (antithetic_success) {
                if (s == -1) {
                    inv_weight += misRatio(pdf_bwd1, pdf_fwd1);
                    assert(t >= 1);
                    pdf_fwd1 *= pdf_fwd2;
                    Intersection& anti_its = light_path.antithetic_vtx[0];
                    Vector wo_local = anti_its.wi;
                    Vector wi = (camera.cpos - anti_its.p).normalized();
                    anti_its.wi = anti_its.toLocal(wi);
                    pdf_bwd1 *= anti_its.pdfBSDF(wo_local);
                    anti_its.wi = wo_local;
                    pdf_bwd1 *= geometric(anti_its.p, light_path[t-1].p, light_path[t-1].geoFrame.n);
                    inv_weight += misRatio(pdf_bwd1, pdf_fwd1);

                    if ( t >= 2 ) {
                        pdf_fwd1 *= light_path[t-2].pdf;
                        Vector wo_local = light_path[t-1].wi;
                        Vector wi = (anti_its.p - light_path[t-1].p).normalized();
                        light_path[t-1].wi = light_path[t-1].toLocal(wi);
                        pdf_bwd1 *= light_path[t-1].pdfBSDF(wo_local);
                        light_path[t-1].wi = wo_local;
                        pdf_bwd1 *= geometric(light_path[t-1].p, light_path[t-2].p, light_path[t-2].geoFrame.n);
                        inv_weight += light_path.weight[t-1] * misRatio(pdf_bwd1, pdf_fwd1);
                    }
                } else {
                    if ( s <= 1) {
                        const Intersection& anti_its = camera_path.antithetic_vtx[s];
                        Vector dir = light_path[t].p - anti_its.p;
                        Float dist2 = dir.squaredNorm();
                        dir /= std::sqrt(dist2);
                        pdf_bwd1 = anti_its.pdfBSDF(anti_its.toLocal(dir));
                        pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
                    }
                    inv_weight += weight_remain * misRatio(pdf_bwd1, pdf_fwd1);

                    if ( t >= 1 ) {
                        if (s == 0) {
                            const Intersection& anti_its = camera_path.antithetic_vtx[s];
                            Vector wo_local = light_path[t].wi;
                            Vector wi = (anti_its.p - light_path[t].p).normalized();
                            light_path[t].wi = light_path[t].toLocal(wi);
                            pdf_bwd2 = light_path[t].pdfBSDF(wo_local);
                            light_path[t].wi = wo_local;
                            pdf_bwd2 *= geometric(light_path[t].p, light_path[t-1].p, light_path[t-1].geoFrame.n);
                        }
                        inv_weight += weight_remain * light_path.weight[t] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
                    }
                }
            }
        }
#endif
        return 1.f / inv_weight;
    }

    template <bool ad>
    void bidirTrace(Type<Scene, ad> &sceneV, Type<BDPTRecord, ad> &b_rec, bool apply_antithetic)
    {
        const Scene &scene = value(sceneV);
        const Camera& camera = scene.camera;
        int num_pixels = camera.getNumPixels();

#ifdef FORWARD
        Float param = 0.f;
        int shape_idx = -1;
        if constexpr(ad) {
            shape_idx = scene.getShapeRequiresGrad();
            sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param = 0;
            param = sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
        }
#endif

        algorithm1_bdpt::LightPath &camera_path = b_rec.getCameraPath();
#ifndef BDPT_PARTICLE
        // sample camera path
        {
            camera_path.isCameraPath = true;
            camera_path.reserve(b_rec.max_bounces + 1);
            const Array2i& pixel_idx = b_rec.pixelIdx;
            Ray ray, ray_ant;
            camera.samplePrimaryRayFromFilter(pixel_idx[0], pixel_idx[1], b_rec.sampler->next2D(),
                                              ray, ray_ant);
            Intersection its;
            if (scene.rayIntersect(ray, false, its)) {
                its.pdf = camera.eval(pixel_idx[0], pixel_idx[1], its.p, its.geoFrame.n);
                its.pdf /= num_pixels;
                its.value = Spectrum(num_pixels);
                camera_path.pixelIdx = b_rec.pixelIdx;
                camera_path.append(its);
                if (apply_antithetic) {
                    Intersection& anti_its = camera_path.antithetic_vtx[0];
                    if (scene.rayIntersect(ray_ant, false, anti_its)) {
                        anti_its.pdf = camera.eval(pixel_idx[0], pixel_idx[1], anti_its.p, anti_its.geoFrame.n);
                        anti_its.pdf /= num_pixels;
                        anti_its.value = Spectrum(num_pixels);
                    }
                }
                sampleLightPath(scene, b_rec.sampler, camera_path);
            }
        }
#endif

        algorithm1_bdpt::LightPath &light_path = b_rec.getLightPath();
        // sample light path
        {
            light_path.isCameraPath = false;
            light_path.reserve(b_rec.max_bounces + 1);
            Intersection its;
            its.value = scene.sampleEmitterPosition(b_rec.sampler->next2D(), its);
            light_path.append(its);
#ifndef BDPT_PT
            sampleLightPath(scene, b_rec.sampler, light_path);
#endif
        }

        auto light_pathAD = algorithm1_bdpt::LightPathAD(light_path);
        auto camera_pathAD = algorithm1_bdpt::LightPathAD(camera_path);
#ifndef BDPT_PARTICLE
        // zero-bounce contrb.
        int pixel_idx = ravel_multi_index(b_rec.pixelIdx, {camera.width, camera.height});
        if (camera_path.size() > 0) {
            Spectrum path_val(0.f);
            bool antithetic_success = false;
            evalBDPTPath(scene, camera_path, 0, light_path, -1, path_val, antithetic_success);
            if (!path_val.isZero()) {
                Float mis_weight = antithetic_success ? 0.5f : 1.0f;
                if((mis_weight * path_val).allFinite())
                    b_rec.image[pixel_idx] += mis_weight * path_val;
                if constexpr (ad) {
                    d_evalPath(sceneV, camera_pathAD, 0, light_pathAD, -1, mis_weight, antithetic_success, b_rec.d_image[pixel_idx]);
                }
            }
        }

        // Path tracer (NEE + no MIS) with antithetic sampling
        // Float pdfRatio = 1.0f;
        // bool antithetic_valid = true;       // cehck if main path is valid
        // for (int s = 0; s < std::min(camera_path.size(), b_rec.max_bounces); s++) {
        //     Spectrum path_val(0.f);
        //     bool antithetic_success = false;
        //     evalBDPTPath(scene, camera_path, s, light_path, 0, path_val, antithetic_success);
        //     if (path_val.isZero()) continue;

        //     Float mis_weight = 1.0f;
        //     if (antithetic_success) {
        //         if (s == 1 || s == 2)
        //             pdfRatio *= misRatio(camera_path.antithetic_vtx[s].pdf, camera_path[s].pdf);
        //         mis_weight = 1.0 / (1.0 + pdfRatio);
        //     }
        //     b_rec.image[pixel_idx] += mis_weight * path_val;
        //     if constexpr (ad) {
        //         Spectrum val = mis_weight * path_val;
        //         algorithm1_bdpt::d_evalPath(sceneV, camera_pathAD, s, light_pathAD, 0, mis_weight, antithetic_success, b_rec.d_image[pixel_idx]);
        //     }
        // }

        // contrb from jittered sampling (i.e. s >= 0)
        for(int nbounce = 1; nbounce <= b_rec.max_bounces; nbounce++) {
            int s_min = std::max(0, nbounce - light_path.size());
            int s_max = std::min(nbounce, camera_path.size() - 1);
            for (int s = s_min; s <= s_max; s++) {
                int t = nbounce - 1 - s;
                Spectrum path_val(0.f);
                bool antithetic_success = false;
                evalBDPTPath(scene, camera_path, s, light_path, t, path_val, antithetic_success);
                if (path_val.isZero()) continue;
                Float mis_weight = b_rec.computeWeight(scene, s, t, 0.f, antithetic_success);
                if((mis_weight * path_val).allFinite())
                    b_rec.image[pixel_idx] += mis_weight * path_val;
                if constexpr (ad) {
                    algorithm1_bdpt::d_evalPath(sceneV, camera_pathAD, s, light_pathAD, t,
                                                mis_weight,antithetic_success, b_rec.d_image[pixel_idx]);
                }
            }
        }

#ifdef FORWARD
        if constexpr (ad) {
            for (int i = camera_path.size()-1; i >= 0; i--) {
                algorithm1_bdpt::d_evalVertex(sceneV, camera_pathAD, i);
                algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, i, false);
            }
            algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, 1, true);
            algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, 0, true);

            for (int i = light_path.size()-1; i >= 0; i--) {
                algorithm1_bdpt::d_evalVertex(sceneV, light_pathAD, i);
                algorithm1_bdpt::d_getPoint(sceneV, light_pathAD, i, false);
            }

            if (isfinite(sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param))
            {
                Float delta = sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param - param;
                param = sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
                b_rec.grad_image[pixel_idx] += Spectrum(delta, 0.f, 0.f);
            }
        }
#endif
#endif

        Array4 sensor_vals;
        Matrix2x4 pix_uvs;
        Vector wo;
        for (int t = 1; t < light_path.size(); t++) {
            Intersection& its = light_path[t];
            if (!scene.isVisible(its.p, true, camera.cpos, true)) continue;
            camera.sampleDirect(its.p, pix_uvs, sensor_vals, wo);
            if (sensor_vals.isZero()) continue;
            Vector wi = its.toWorld(its.wi);
            Vector wo_local = its.toLocal(wo);
            /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) continue;
            for (int i = 0; i < 4; i++) {
                if (sensor_vals(i) < Epsilon) continue;
                Spectrum path_val = sensor_vals(i) * its.value * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection);
                if (path_val.isZero()) continue;
                bool antithetic_success = false;
                int pixel_idx = camera.getPixelIndex(pix_uvs.col(i));
                light_path.pixelIdx = unravel_index(pixel_idx, {camera.width, camera.height});
                Intersection& anti_its = light_path.antithetic_vtx[0];
                if (apply_antithetic) {
                    Vector2 pixel_uv;
                    Vector dir;
                    camera.sampleDirect(its.p, pixel_uv, dir);
                    Ray dual = camera.sampleDualRay(light_path.pixelIdx, pixel_uv);
                    antithetic_success = scene.rayIntersect(dual, false, anti_its) &&
                                        scene.isVisible(anti_its.p, true, light_path[t-1].p, true);
                    if (antithetic_success) {
                        Spectrum contrb_anti = light_path[t-1].value;
                        Vector wo2 = (anti_its.p - light_path[t-1].p).normalized();
                        Vector wo2_local = light_path[t-1].toLocal(wo2);
                        Float G = geometric(light_path[t-1].p, anti_its.p, anti_its.geoFrame.n);
                        if ( t == 1 ) {
                            const Emitter* ptr_emitter = light_path[0].ptr_emitter;
                            contrb_anti *= ptr_emitter->evalDirection(light_path[0].geoFrame.n, wo2) * G / light_path[1].pdf;
                        } else {
                            if (wo2.dot(light_path[t-1].geoFrame.n) * wo2_local.z() < 0)
                                contrb_anti.setZero();
                            else
                                contrb_anti *= light_path[t-1].evalBSDF(wo2_local, EBSDFMode::EImportanceWithCorrection) * G / light_path[t].pdf;
                        }
                        antithetic_success = !contrb_anti.isZero();

                        if (antithetic_success) {
                            Vector dir2cam_local = anti_its.wi;
                            anti_its.wi = anti_its.toLocal(-wo2);
                            Float wiDotGeoN2 = -wo2.dot(anti_its.geoFrame.n), woDotGeoN2 = -dual.dir.dot(anti_its.geoFrame.n);
                            if (wiDotGeoN2 * anti_its.wi.z() <= 0 || woDotGeoN2 * dir2cam_local.z() <= 0)
                                contrb_anti.setZero();
                            else
                                contrb_anti *= anti_its.evalBSDF(dir2cam_local, EBSDFMode::EImportanceWithCorrection);

                            antithetic_success = !contrb_anti.isZero();
                            if (antithetic_success) {
                                anti_its.pdf = t == 1 ? light_path[0].pdfEmitter(wo2_local) * G          // should use geo. normal here
                                                    : light_path[t-1].pdfBSDF(wo2_local) * G;
                                anti_its.pdf *= camera.eval(light_path.pixelIdx[0], light_path.pixelIdx[1],
                                                            light_path[t].p, light_path[t].geoFrame.n);
                                anti_its.pdf /= camera.eval(light_path.pixelIdx[0], light_path.pixelIdx[1],
                                                            anti_its.p, anti_its.geoFrame.n);
                                Float J = woDotGeoN / woDotGeoN2;
                                contrb_anti *= sensor_vals(i) * J;
                                path_val += contrb_anti;
                            }
                        }
                    }
                }
#ifdef BDPT_PARTICLE
                //Particle tracer with antithetic sampling
                Float mis_weight = 1.0f;
                if (antithetic_success) {
                    Float pdf_anti = anti_its.pdf / camera.geometric(anti_its.p, anti_its.geoFrame.n);
                    pdf_anti *= camera.geometric(light_path[t].p, light_path[t].geoFrame.n);
                    mis_weight = 1.0 / (1.0 + misRatio(pdf_anti, light_path[t].pdf));
                }
                b_rec.image[pixel_idx] += path_val * mis_weight;
#else               
                Float mis_weight = b_rec.computeWeight(scene, -1, t, sensor_vals(i), antithetic_success);
                b_rec.image[pixel_idx] += path_val * mis_weight;
#endif
                if constexpr (ad) {
                    light_pathAD.val.antithetic_vtx[0] = light_path.antithetic_vtx[0];
                    light_pathAD.val.pixelIdx = light_path.pixelIdx;
                    light_pathAD.der.antithetic_vtx[0] = light_path.antithetic_vtx[0];
                    light_pathAD.der.antithetic_vtx[0].setZero();
                    algorithm1_bdpt::d_evalPath(sceneV, camera_pathAD, -1, light_pathAD, t,
                                                mis_weight, antithetic_success, b_rec.d_image[pixel_idx]);
#ifdef FORWARD
                    for (int i = t; i >= 0; i--) {
                        algorithm1_bdpt::d_evalVertex(sceneV, light_pathAD, i);
                        if (i == t && antithetic_success)
                            algorithm1_bdpt::d_getPoint(sceneV, light_pathAD, i, true);
                        algorithm1_bdpt::d_getPoint(sceneV, light_pathAD, i, false);
                    }

                    if (isfinite(sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param))
                    {
                        Float delta = sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param - param;
                        param = sceneV.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
                        b_rec.grad_image[pixel_idx] += Spectrum(delta, 0.f, 0.f);
                    }
#else
                    if (antithetic_success)
                        algorithm1_bdpt::d_getPoint(sceneV, light_pathAD, t, true);
#endif
                }
            }
        }

#ifndef FORWARD
        if constexpr (ad) {
#ifndef BDPT_PARTICLE
            for (int i = camera_path.size()-1; i >= 0; i--) {
                algorithm1_bdpt::d_evalVertex(sceneV, camera_pathAD, i);
                algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, i, false);
            }
            algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, 1, true);
            algorithm1_bdpt::d_getPoint(sceneV, camera_pathAD, 0, true);
#endif
            for (int i = light_path.size()-1; i >= 0; i--) {
                algorithm1_bdpt::d_evalVertex(sceneV, light_pathAD, i);
                algorithm1_bdpt::d_getPoint(sceneV, light_pathAD, i, false);
            }
        }
#endif
    }
}

/// Render the primal image with bidirectional path tracing.
///
/// The image is split into 16x16 tiles that are handed to OpenMP workers
/// (dynamic schedule, one tile at a time). Every pixel draws
/// `options.num_samples` samples via `bidirTrace<false>`. Each worker writes
/// into its own buffer obtained from `ThreadManager`, and the buffers are
/// merged into `spec_list` after the parallel loop.
///
/// @param scene   scene to render; its camera defines the image resolution
/// @param options render settings (reads `seed`, `num_samples`, `max_bounces`)
/// @return accumulated image divided by (number of pixels * samples per pixel),
///         converted with `from_spectrum_list_to_tensor`
ArrayXd BDPT2::renderC(const Scene &scene, const RenderOptions &options) const
{
    const auto &camera = scene.camera;
    const int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    std::vector<Spectrum> spec_list(num_pixels, Spectrum::Zero());
    // One private accumulation buffer per worker; merged into spec_list below.
    ThreadManager thread_manager(spec_list, nworker);
    int blockProcessed = 0;
    Timer _("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // A loop iteration runs entirely on one thread, so the id is fixed here.
        const int thread_id = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            // Per-pixel seeded sampler keeps results independent of thread scheduling.
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.num_samples; j++)
            {
                BDPTRecord b_rec = BDPTRecord(&sampler, options.max_bounces, pixelIdx, thread_manager.get(thread_id));
                bidirTrace<false>(scene, b_rec, mApplyAntithetic);
            }
        }

        if (verbose)
        {
            // blockProcessed is shared across workers; serialize its update.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
        }
    }
    if (verbose)
        std::cout << std::endl;

    thread_manager.merge();
    // Normalize by the total sample count (pixels * samples per pixel).
    size_t num_samples = size_t(num_pixels) * options.num_samples;
    for (auto &spec : spec_list)
        spec /= num_samples;
    return from_spectrum_list_to_tensor(spec_list, num_pixels);
}

/// Differentiate the BDPT render with respect to the scene.
///
/// Runs `bidirTrace<true>` with the same tiling, seeding and sample counts as
/// `renderC`. The incoming adjoint image `d_image` is pre-scaled by
/// 1 / (num_pixels * num_samples), matching renderC's normalization.
/// Per-thread scene gradients are merged into `sceneAD.gm` at the end.
///
/// @param sceneAD scene with primal values (`val`) and gradient storage (`gm`)
/// @param options render settings (reads `seed`, `num_samples`, `max_bounces`)
/// @param d_image adjoint of the rendered image (dL/dI)
/// @return the per-pixel `grad_image`, merged across threads and NOT normalized.
///         In the visible bidirTrace code it is only written when FORWARD is
///         defined.
ArrayXd BDPT2::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const
{
    const auto &camera = sceneAD.val.camera;
    const int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    // Primal image: bidirTrace<true> still accumulates into it, but it is not returned.
    std::vector<Spectrum> spec_list(num_pixels, Spectrum::Zero());
    ThreadManager thread_manager(spec_list, nworker);
    // Adjoint image, scaled like renderC's output normalization.
    auto d_image_spec_list = from_tensor_to_spectrum_list(
        d_image / num_pixels / options.num_samples, num_pixels);
    // Gradient image, accumulated per thread and merged below.
    std::vector<Spectrum> grad_image(num_pixels, Spectrum::Zero());
    ThreadManager grad_images(grad_image, nworker);
    int blockProcessed = 0;
    Timer _("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // A loop iteration runs entirely on one thread, so the id is fixed here.
        const int thread_id = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            // Same per-pixel seeding as renderC, so primal and adjoint passes draw the same samples.
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.num_samples; j++)
            {
                // d_image_spec_list is shared by all threads; presumably only read in bidirTrace — TODO confirm.
                BDPTRecordAD b_recAD = BDPTRecordAD(&sampler, options.max_bounces, pixelIdx,
                                                  thread_manager.get(thread_id), d_image_spec_list, grad_images.get(thread_id));
                bidirTrace<true>(sceneAD, b_recAD, mApplyAntithetic);
            }
        }

        if (verbose)
        {
            // blockProcessed is shared across workers; serialize its update.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
        }
    }
    if (verbose)
        std::cout << std::endl;

    thread_manager.merge();
    grad_images.merge();
    sceneAD.gm.merge();
    return from_spectrum_list_to_tensor(grad_image, num_pixels);
}