#include "bdptwas.h"
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"

// #define BDPT_PARTICLE

namespace
{
    /// State carried through one bidirectional sample in the primal (non-AD) pass.
    ///
    /// Holds the random sampler, the bounce limit, the pixel the camera sub-path
    /// starts from, a reference to the buffer radiance is accumulated into, and the
    /// two sub-paths built by the tracer.
    struct BDPTWASRecord
    {
        BDPTWASRecord(RndSampler* sampler_in, int bounce_limit, const Array2i& pixel, std::vector<Spectrum>& target)
            : sampler(sampler_in)
            , max_bounces(bounce_limit)
            , pixelIdx(pixel)
            , image(target)
        {
        }

        // Non-owning pointer to the sampler; the record never deletes it.
        mutable RndSampler *sampler;
        // Bounce limit; the tracer reserves max_bounces + 1 vertices per sub-path.
        int max_bounces;
        // Pixel through which the camera sub-path is sampled.
        Array2i pixelIdx;
        // Accumulation buffer; the tracer indexes it by raveled pixel index.
        std::vector<Spectrum> &image;

        // Sub-paths are `mutable` so the const accessors below can hand out
        // writable references.
        mutable algorithm1_bdptwas::LightPath light_path;
        mutable algorithm1_bdptwas::LightPath camera_path;

        algorithm1_bdptwas::LightPath &getLightPath() const
        {
            return light_path;
        }

        algorithm1_bdptwas::LightPath &getCameraPath() const
        {
            return camera_path;
        }

        // MIS weight of the strategy that connects camera vertex `s` with light
        // vertex `t` (-1 means "no vertex on that side"). Defined out of line.
        Float computeWeight(const Scene& scene, int s, int t, Float sensor_val);
    };

    /// Record used by the differentiable pass: everything in BDPTWASRecord plus
    /// references to two additional per-pixel buffers.
    struct BDPTWASRecordAD : BDPTWASRecord
    {
        BDPTWASRecordAD(RndSampler* sampler_in, int bounce_limit, const Array2i& pixel,
                        std::vector<Spectrum>& target, std::vector<Spectrum>& d_target, std::vector<Spectrum>& grad_target)
            : BDPTWASRecord(sampler_in, bounce_limit, pixel, target)
            , d_image(d_target)
            , grad_image(grad_target)
        {
        }

        // Per-pixel buffer handed to the d_evalBoundary* routines by the tracer.
        std::vector<Spectrum> &d_image;
        // Per-pixel buffer the tracer adds parameter deltas to when FORWARD is defined.
        std::vector<Spectrum> &grad_image;
    };

    /// Compile-time map from a primal type `T` and an `ad` flag to the type the
    /// tracer actually works with. Only the specialisations below are defined.
    template <typename T, bool ad>
    struct type_traits;

    // Scene: the primal pass sees a const Scene, the AD pass a SceneAD.
    template <>
    struct type_traits<Scene, false>
    {
        using type = const Scene;
    };

    template <>
    struct type_traits<Scene, true>
    {
        using type = SceneAD;
    };

    // Record: the primal pass uses BDPTWASRecord, the AD pass BDPTWASRecordAD.
    template <>
    struct type_traits<BDPTWASRecord, false>
    {
        using type = BDPTWASRecord;
    };

    template <>
    struct type_traits<BDPTWASRecord, true>
    {
        using type = BDPTWASRecordAD;
    };

    /// Shorthand for `type_traits<T, ad>::type`.
    template <typename T, bool ad>
    using Type = typename type_traits<T, ad>::type;

    /// Returns the primal scene behind either scene flavour.
    template <typename T>
    auto &value(T &t);

    // A plain scene is already its own primal value.
    template <>
    auto &value(const Scene &t)
    {
        return t;
    }

    // An AD scene exposes its primal value through `.val`.
    template <>
    auto &value(SceneAD &t)
    {
        return t.val;
    }

    /// Squared ratio (pdf0 / pdf1)^2 used when accumulating MIS weights.
    /// No guard against pdf1 == 0; the result then follows IEEE division rules.
    Float misRatio(Float pdf0, Float pdf1)
    {
        const Float ratio = pdf0 / pdf1;
        return ratio * ratio;
    }

    /// Evaluates the unweighted contribution of the strategy that joins camera
    /// vertex `s` with light vertex `t`.
    ///
    /// With t == -1 the camera sub-path is used on its own and contributes only if
    /// its vertex `s` is an emitter. Otherwise the two vertices are connected by a
    /// visibility-tested segment. `contrb_all` is written only when a contribution
    /// exists; in every other case it keeps the value the caller passed in.
    void evalBDPTPath(const Scene& scene, const algorithm1_bdptwas::LightPath& camera_path, int s, const algorithm1_bdptwas::LightPath& light_path, int t,
                    Spectrum& contrb_all)
    {
        const Intersection& cam_vtx = camera_path[s];

        if (t == -1) {
            // Camera sub-path alone: only emitted radiance counts.
            if (!cam_vtx.isEmitter()) return;

            // Point the emitter is seen from: the camera itself for s == 0,
            // otherwise the previous camera vertex.
            const Vector& seen_from = (s == 0) ? scene.camera.cpos : camera_path[s-1].p;
            const Vector to_viewer = (seen_from - cam_vtx.p).normalized();
            contrb_all = cam_vtx.Le(to_viewer) * cam_vtx.value;
            return;
        }

        const Intersection& lgt_vtx = light_path[t];
        if (!scene.isVisible(lgt_vtx.p, true, cam_vtx.p, true)) return;

        // Unit direction from the camera vertex to the light vertex, plus the
        // squared length of the connecting segment.
        Vector cam_to_lgt = lgt_vtx.p - cam_vtx.p;
        const Float dist_sq = cam_to_lgt.squaredNorm();
        cam_to_lgt /= std::sqrt(dist_sq);

        // Light-side term: emitter directional profile at t == 0, BSDF otherwise.
        Spectrum light_term = (t == 0) ? Spectrum::Ones() * lgt_vtx.ptr_emitter->evalDirection(lgt_vtx.geoFrame.n, -cam_to_lgt)
                                       : lgt_vtx.evalBSDF(lgt_vtx.toLocal(-cam_to_lgt), EBSDFMode::EImportanceWithCorrection);
        light_term /= dist_sq;

        const Spectrum camera_term = cam_vtx.evalBSDF(cam_vtx.toLocal(cam_to_lgt), EBSDFMode::ERadiance);
        contrb_all = camera_term * light_term * cam_vtx.value * lgt_vtx.value;
    }

    int sampleSurface(const Vector2 &rnd2, const Scene& scene, const Intersection &primal, Intersection& aux, Float sigma) {
        // Vector2 rnd2(_rnd2);
        Vector2 _rnd2(rnd2);
        // std::cout << "<------sample surface------>" << std::endl;
        Float rnd1 = _rnd2[0];
        _rnd2[0] = _rnd2[0] > 0.5 ? (_rnd2[0] - 0.5) * 2.0 : _rnd2[0] * 2.0;
        const Shape* shape = scene.getShape(primal.shape_id);
        const Vector3i &ind = shape->indices[primal.triangle_id];
        const Vector &v0 = shape->vertices[ind[0]], &v1 = shape->vertices[ind[1]], &v2 = shape->vertices[ind[2]];
        Vector2 barycentric = primal.barycentric;
        Vector x0 = (1.0f - primal.barycentric[0] - primal.barycentric[1])*v0 + primal.barycentric[0]*v1 + primal.barycentric[1]*v2;
        
        // std::cout << barycentric << std::endl;
        Vector coord_system[3];
        coord_system[0] = shape->faceNormals[primal.triangle_id];
        coordinateSystem(coord_system[0], coord_system[1], coord_system[2]);
        // Intersection its;
        Ray ray;
        Vector2 diskpos = squareToGaussianDisk(_rnd2, sigma);
        int axis = 0;
        
        Vector sphereCoord = Vector(diskpos(0), diskpos(1), 1e-2);
        Vector origin = x0 + sphereCoord[2] * coord_system[axis]// * direction
                            + sphereCoord[0] * coord_system[(axis + 1) % 3]
                            + sphereCoord[1] * coord_system[(axis + 2) % 3];
        
        Intersection its0, its1;
        //one ray down
        ray = Ray(origin, -coord_system[axis]);
        bool its0_exists = scene.rayIntersect(ray, its0, IntersectionMode::ESpatial);
        if (its0.t > 0.1) its0_exists = false;
        if (its0.shape_id != primal.shape_id) its0_exists = false;

        //one ray up
        ray = Ray(origin, coord_system[axis]);
        bool its1_exists = scene.rayIntersect(ray, its1, IntersectionMode::ESpatial);
        // std::cout << "origin: " << origin << std::endl;
        // std::cout << "ray_dir: " << coord_system[axis] << std::endl;
        // std::cout << "its.p: " << its0.p << ", " << its1.p << std::endl;
        if (its1.t > 0.1) its1_exists = false;
        if (its1.shape_id != primal.shape_id) its1_exists = false;
        // std::cout << "exists: " << its0_exists << ", " << its1_exists << std::endl;
        if (its0_exists && its1_exists) {
            if (rnd1 > 0.5) 
                aux = its0;
            else 
                aux = its1;
            return 2;
        } else if (its0_exists) {
            aux = its0;
            return 1;
        } else if (its1_exists) {
            aux = its1;
            return 1;
        }
        else {
            return 0;
        }
        // std::cout << "<----sample surface end---->" << std::endl;
    }

    void sampleSurfacePair(const Vector2 &rnd2, const Scene& scene, const Intersection &primal, Intersection& aux_0, Intersection& aux_1, int &found0, int &found1, Float sigma) {
        Vector2 _rnd2 = rnd2;
        found0 = sampleSurface(_rnd2, scene, primal, aux_0, sigma);
        _rnd2[1] = (_rnd2[1] + 0.5) > 1 ? _rnd2[1] - 0.5 : _rnd2[1] + 0.5;
        found1 = sampleSurface(_rnd2, scene, primal, aux_1, sigma);
    }

    /// Gaussian-disk density of the offset from `primal` to `aux`, measured in the
    /// tangent plane of the primal triangle's face normal.
    Float sampleSurfacePdf(const Scene& scene, const Intersection &primal, const Intersection &aux, Float sigma) {
        const Shape* shape = scene.getShape(primal.shape_id);

        // Same frame construction as in sampleSurface().
        Vector normal = shape->faceNormals[primal.triangle_id];
        Vector tangent, bitangent;
        coordinateSystem(normal, tangent, bitangent);

        // Project the world-space offset onto the two tangent axes.
        Vector offset = aux.p - primal.p;
        Vector2 disk_pos = Vector2(tangent.dot(offset), bitangent.dot(offset));

        return squareToGaussianDiskPdf(disk_pos, sigma);
    }

    void harmonic_weight(const Vector &prim, const Vector &aux, const Float &rayLength, const Float &sigma, const Float &B, Float &w, Vector &dw) { // hand derived
        Float dist = (prim - aux).squaredNorm();
        Float D = pow(rayLength / sigma, algorithm1_bdptwas::power) * (1.0 - exp(-dist / sigma / rayLength));
        w = 1.0 / ((pow(D, 3) + B));
        dw = -3.0 * pow(D, 2) / pow((pow(D, 3) + B), 2) * pow(rayLength / sigma, algorithm1_bdptwas::power) * (exp(-dist / sigma / rayLength) / sigma / rayLength * 2.0 * (prim - aux));
    }

    void sampleVertex(const Scene& scene, RndSampler *sampler, const Intersection &cur, const Intersection &prev, 
                    std::vector<algorithm1_bdptwas::WASCache> &cache_list) {
        
        // std::cout << "<---------------sample Vertex-------------->" << std::endl;
        int valid_sample[64];
        Float rayLength = std::sqrt((prev.p - cur.p).norm());
        Float sigma = algorithm1_bdptwas::base_sigma * rayLength;
        bool ANTITHETIC = true;
        assert(!ANTITHETIC || (algorithm1_bdptwas::NUM_AUX_SAMPLES % 2 == 0));
        Intersection antithetic_buffer;
        cache_list.resize(algorithm1_bdptwas::NUM_AUX_SAMPLES);
        
        for (int i = 0; i < algorithm1_bdptwas::NUM_AUX_SAMPLES; i++) {
            // std::cout << "<-------------iter " << i << " ------------>" << std::endl;
            Intersection x_aux;
            x_aux.shape_id = cur.shape_id;
            algorithm1_bdptwas::WASCache was_cache;
            // std::cout << "x bary" << x->its.barycentric << std::endl;
            if (!ANTITHETIC) {
                valid_sample[i] = sampleSurface(sampler->next2D(), scene, cur, x_aux, sigma);
            } else if (i % 2 == 0) {
                sampleSurfacePair(sampler->next2D(), scene, cur, 
                                    x_aux, antithetic_buffer, valid_sample[i], valid_sample[i + 1], sigma);
            } else {
                x_aux = antithetic_buffer;
            }
            
            if (valid_sample[i] == 0) {
                was_cache.VALID = false;
                cache_list[i] = (was_cache);
                continue;
            }
            
            Vector n_local = cur.geoFrame.toLocal(x_aux.geoFrame.n);
            Float area_J = n_local[2];
            Float pdf = sampleSurfacePdf(scene, cur, x_aux, sigma) * area_J / valid_sample[i];
            was_cache.aux = x_aux;

            Intersection isec_half;
            Vector dir = (x_aux.p - prev.p).normalized();
            Ray ray(prev.p, dir);
            bool half_found = scene.rayIntersect(ray, true, isec_half);
            Float B = 0.0;
            if (!half_found) {
                was_cache.VALID = false;
                cache_list[i] = (was_cache);
                continue;
            }
            // if (dir.dot(prev.geoFrame.n) < 1e-2) { // if aux_point is in the back hemisphere
            //     B = (scene.shape_list[isec_half.shape_id])->B(ray, isec_half);
            //     was_cache.force_zero = true;
            // }
            if (isec_half.t < 1e-5) {
                was_cache.force_zero = true;
            }
            if ((isec_half.p - x_aux.p).norm() < 1e-3) {
                
                was_cache.force_zero = true;
            }
            if (isec_half.shape_id < 0) {
                was_cache.VALID = false;
                cache_list[i] = (was_cache);
                continue;
            }
            B = (scene.shape_list[isec_half.shape_id])->B(ray, isec_half);
            harmonic_weight(cur.p, x_aux.p, rayLength, sigma, B, was_cache.w, was_cache.dw);
            was_cache.w /= pdf;
            was_cache.dw /= pdf;
            if ((isec_half.p - x_aux.p).norm() > 1e-3) {
                // std::cout << "<-------------sampleVertex " << i << "------------>" << std::endl;
                // std::cout << "force_zero: " << was_cache.force_zero << std::endl;
                // std::cout << "prim: " << cur.p << ", " << cur.shape_id << ", " << cur.triangle_id << std::endl;
                // std::cout << "aux: " << x_aux.p << ", " << x_aux.shape_id << ", " << x_aux.triangle_id << std::endl;
                // std::cout << "w, dw: " << was_cache.w << ", " << was_cache.dw << std::endl;
                // std::cout << "aux_half: " << isec_half.p << ", " << isec_half.shape_id << std::endl;
            }
            was_cache.aux_half = isec_half;
            cache_list[i] = (was_cache);
        }
        // std::cout << cache_list.size() << "\n";
        // std::cout << "<-------------sample Vertex end------------>" << std::endl;
    }


    void sampleLightPath(const Scene& scene, RndSampler* sampler, algorithm1_bdptwas::LightPath& path, bool sampleAD)
    {
        int max_bounces = path.vertices.capacity() - 1;
        Float pdf;
        Ray ray;
        Intersection its = path[0];
        for (int ibounce = 0; ibounce < max_bounces; ibounce++) {
            Spectrum throughput(1.0f);
            if (ibounce == 0 && !path.isCameraPath) {
                throughput *= its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
                if (its.ptr_shape != nullptr)
                    ray.dir = its.geoFrame.toWorld(ray.dir);
            } else {
                Float bsdf_eta;
                Vector wo_local, wo;
                EBSDFMode mode = path.isCameraPath ? EBSDFMode::ERadiance : EBSDFMode::EImportanceWithCorrection;
                throughput *= its.sampleBSDF(sampler->next3D(), wo_local, pdf, bsdf_eta, mode);
                wo = its.toWorld(wo_local);
                // check light leak
                Vector wi = its.toWorld(its.wi);
                Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                    throughput.setZero();
                ray.dir = wo;
            }
            ray.org = its.p;
            Intersection its_debug = its;
            Intersection its_prev = its;
            if (throughput.isZero(Epsilon) || !scene.rayIntersect(ray, true, its)) break;
            Float G = geometric(ray.org, its.p, its.geoFrame.n);
            if (G < Epsilon || its.t < ShadowEpsilon) break;     // Hack for fixing numerical issues
            its.pdf = pdf * G;
            its.value = path[ibounce].value * throughput;
            path.append(its);
        }
        
        Float pdf1;
        Float pdf_remain_main = 1.0f;
        Float pdf_remain_anti = 1.0f;
        Float weight_anti = 1.0f;

        for (int ibounce = 1; ibounce < path.size(); ibounce++) {
            path.weight[ibounce] = 1.0f;
            if (ibounce > 1) {
                path.weight[ibounce] += path.weight[ibounce-1] * misRatio(pdf1, path[ibounce-2].pdf);
            }

            if (ibounce < path.size() - 1) {
                const Intersection &its0 = path[ibounce-1];
                Intersection &its1 = path[ibounce];
                Vector wo_local = its1.wi;
                Vector wi = (path[ibounce+1].p - its1.p).normalized();
                its1.wi = its1.toLocal(wi);
                pdf1 = its1.pdfBSDF(wo_local) * geometric(its1.p, its0.p, its0.geoFrame.n);
                its1.wi = wo_local;
            }
        }
        if (sampleAD) {
            if (path.isCameraPath) {
                for (int ibounce = 1; ibounce < path.size(); ibounce++) {
                    Intersection &cur = path[ibounce];
                    Intersection &prev = path[ibounce-1];
                    sampleVertex(scene, sampler, cur, prev, path.vertex_container[ibounce].vertex_cache);
                }
            } else {
                for (int ibounce = 1; ibounce < path.size(); ibounce++) {
                    Intersection &prev = path[ibounce];
                    Intersection &cur = path[ibounce-1];
                    sampleVertex(scene, sampler, cur, prev, path.vertex_container[ibounce-1].vertex_cache);
                }
            }
        }
    }

    /// MIS weight of the strategy that connects camera vertex `s` to light vertex `t`.
    ///
    /// An index of -1 means "no vertex on that side": t == -1 is the camera
    /// sub-path hitting an emitter on its own, s == -1 is a light vertex connected
    /// directly to the camera. The weight is 1 / (1 + sum of squared pdf ratios of
    /// the competing strategies); ratios further along each sub-path are folded in
    /// through the partial sums stored in `camera_path.weight` / `light_path.weight`.
    ///
    /// @param scene      scene providing the camera and emitter-sampling pdf
    /// @param s          camera sub-path vertex index, or -1
    /// @param t          light sub-path vertex index, or -1
    /// @param sensor_val sensor response; only read when s == -1
    /// @return MIS weight for the (s, t) strategy
    Float BDPTWASRecord::computeWeight(const Scene& scene, int s, int t, Float sensor_val)
    {
        const Camera& camera = scene.camera;
        Float pdf_fwd1, pdf_fwd2, pdf_bwd1, pdf_bwd2;
        Float inv_weight = 1.0f;

#ifdef BDPT_PT
        assert(s >= 0);
        if (t == -1) {
            pdf_fwd1 = camera_path[s].pdf;
            pdf_bwd1 = scene.pdfEmitterSample(camera_path[s]);
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s-1, t+1) / pdf(s, t)
        } else {
            assert(t == 0);
            pdf_fwd1 = light_path[t].pdf;
            Vector dir = light_path[t].p - camera_path[s].p;
            Float dist2 = dir.squaredNorm();
            dir /= std::sqrt(dist2);
            pdf_bwd1 = camera_path[s].pdfBSDF(camera_path[s].toLocal(dir));
            pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s+1, t-1) / pdf(s, t)
        }

#else
        // Strategies that would have produced camera vertex s (and s-1) from the light side.
        if (s >= 0) {
            pdf_fwd1 = camera_path[s].pdf;
            Vector dir;     // unit direction light[t] -> camera[s]; only set when t >= 0
            if (t >= 0) {
                dir = camera_path[s].p - light_path[t].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = (t == 0) ? light_path[t].pdfEmitter(light_path[t].toLocal(dir))      // Here maybe we should use geometric normal
                                : light_path[t].pdfBSDF(light_path[t].toLocal(dir));
                // std::abs selects the floating-point overload explicitly, matching
                // the other cosine terms in this function.
                pdf_bwd1 *= std::abs(dir.dot(camera_path[s].geoFrame.n)) / dist2;
            } else {
                pdf_bwd1 = scene.pdfEmitterSample(camera_path[s]);
            }

            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s-1, t+1) / pdf(s, t)
            if (s >= 1) {
                pdf_fwd2 = camera_path[s-1].pdf;
                Vector wo_local = camera_path[s].wi;
                if (t >= 0) {
                    // Temporarily swap the incident direction to query the reverse
                    // BSDF pdf, then restore it.
                    camera_path[s].wi = camera_path[s].toLocal(-dir);
                    pdf_bwd2 = camera_path[s].pdfBSDF(wo_local);
                    camera_path[s].wi = wo_local;
                } else {
                    pdf_bwd2 = camera_path[s].pdfEmitter(wo_local);
                }
                pdf_bwd2 *= geometric(camera_path[s].p, camera_path[s-1].p, camera_path[s-1].geoFrame.n);
                inv_weight += camera_path.weight[s] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }
        }

        // Strategies that would have produced light vertex t (and t-1) from the camera side.
        if ( t >= 0 ) {
            pdf_fwd1 = light_path[t].pdf;
            if (s >= 0) {
                Vector dir = light_path[t].p - camera_path[s].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = camera_path[s].pdfBSDF(camera_path[s].toLocal(dir));
                pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            } else {
                Vector dir = (light_path[t].p - camera.cpos).normalized(); 
                pdf_bwd1 = sensor_val / camera.getNumPixels() * std::abs(dir.dot(light_path[t].geoFrame.n));
            }
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s+1, t-1) / pdf(s, t)

            if ( t >= 1 ) {
                pdf_fwd2 = light_path[t-1].pdf;
                Vector wo_local = light_path[t].wi;
                Vector wi = (((s >= 0) ? camera_path[s].p : camera.cpos) - light_path[t].p).normalized(); 
                light_path[t].wi = light_path[t].toLocal(wi);
                pdf_bwd2 = light_path[t].pdfBSDF(wo_local);
                light_path[t].wi = wo_local;
                pdf_bwd2 *= geometric(light_path[t].p, light_path[t-1].p, light_path[t-1].geoFrame.n);
                // Sum(pdf(s+t-j, j)) / pdf(s, t), where -1 <= j <= t-2
                inv_weight += light_path.weight[t] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }

        }
#endif
        return 1.f / inv_weight;
    }

    /// Traces one bidirectional sample for the pixel stored in `b_rec` and
    /// accumulates its contributions.
    ///
    /// Steps, in order:
    ///   1. sample a camera sub-path through b_rec.pixelIdx (skipped under BDPT_PARTICLE);
    ///   2. sample a light sub-path starting from a sampled emitter position;
    ///   3. for every path length, evaluate each (s, t) connection with evalBDPTPath(),
    ///      weight it with computeWeight() and add it to b_rec.image at the record's pixel;
    ///   4. connect every light vertex t >= 1 directly to the camera via
    ///      camera.sampleDirect() and add the result to the pixels it reports.
    /// When `ad` is true the algorithm1_bdptwas::d_* routines are additionally
    /// invoked for each contribution, passing b_rec.d_image at the affected pixel.
    ///
    /// @tparam ad     false: primal pass on (const Scene, BDPTWASRecord);
    ///                true:  AD pass on (SceneAD, BDPTWASRecordAD)
    /// @param sceneV  scene in the flavour selected by `ad`
    /// @param b_rec   per-sample record; its sampler is advanced and its image
    ///                buffers are written
    template <bool ad>
    void bidirTrace(Type<Scene, ad> &sceneV, Type<BDPTWASRecord, ad> &b_rec)
    {
        // thread_id is only read inside the FORWARD-guarded sections below.
        int thread_id = omp_get_thread_num();
        const Scene &scene = value(sceneV);
        const Camera& camera = scene.camera;
        int num_pixels = camera.getNumPixels();

#ifdef FORWARD
        // Forward-mode bookkeeping: zero the tracked shape parameter so that the
        // change it accumulates during this sample can be read back as a delta.
        Float param = 0.f;
        int shape_idx = -1;
        if constexpr(ad) {
            shape_idx = scene.getShapeRequiresGrad();
            sceneV.gm.get(thread_id).shape_list[shape_idx]->param = 0;
            param = sceneV.gm.get(thread_id).shape_list[shape_idx]->param;
        }
#endif

        algorithm1_bdptwas::LightPath &camera_path = b_rec.getCameraPath();
#ifndef BDPT_PARTICLE
        // sample camera path
        {
            camera_path.isCameraPath = true;
            camera_path.reserve(b_rec.max_bounces + 1);
            const Array2i& pixel_idx = b_rec.pixelIdx;
            Ray ray, ray_ant;
            camera.samplePrimaryRayFromFilter(pixel_idx[0], pixel_idx[1], b_rec.sampler->next2D(),
                                              ray, ray_ant);
            Intersection its;
            // If the primary ray misses, the camera sub-path stays empty.
            if (scene.rayIntersect(ray, false, its)) {
                its.pdf = camera.eval(pixel_idx[0], pixel_idx[1], its.p, its.geoFrame.n);
                its.pdf /= num_pixels;
                its.value = Spectrum(num_pixels);
                camera_path.pixelIdx = b_rec.pixelIdx;
                camera_path.append(its);
                sampleLightPath(scene, b_rec.sampler, camera_path, ad);
            }
        }
#endif

        algorithm1_bdptwas::LightPath &light_path = b_rec.getLightPath();
        // sample light path
        {
            light_path.isCameraPath = false;
            light_path.reserve(b_rec.max_bounces + 1);
            Intersection its;
            its.value = scene.sampleEmitterPosition(b_rec.sampler->next2D(), its);
            light_path.append(its);
#ifndef BDPT_PT
            sampleLightPath(scene, b_rec.sampler, light_path, ad);
#endif
        }

        // AD wrappers around both sub-paths (value part + derivative part).
        auto light_pathAD = algorithm1_bdptwas::LightPathAD(light_path);
        auto camera_pathAD = algorithm1_bdptwas::LightPathAD(camera_path);
#ifndef BDPT_PARTICLE
        // zero-bounce contrb.
        int pixel_idx = ravel_multi_index(b_rec.pixelIdx, {camera.width, camera.height});
        // contrb from jittered sampling (i.e. s >= 0)
        if constexpr (ad) {
            camera_pathAD.der.resize(camera_pathAD.val);
            light_pathAD.der.resize(light_pathAD.val);
            camera_pathAD.der.setZero();
            light_pathAD.der.setZero();
            if (algorithm1_bdptwas::use_prefix_sum){
                algorithm1_bdptwas::evalWarpPath(scene, light_pathAD.val);
                algorithm1_bdptwas::evalWarpPath(scene, camera_pathAD.val);
                algorithm1_bdptwas::evalPrefix(light_pathAD.val);
                algorithm1_bdptwas::evalPrefix(camera_pathAD.val);
            }
        }
        for(int nbounce = 1; nbounce <= b_rec.max_bounces; nbounce++) {
            // Valid camera indices for this path length: t = nbounce - 1 - s must
            // stay below light_path.size(), and s must exist in the camera sub-path.
            // An empty camera sub-path gives s_max == -1, so the inner loop is skipped.
            int s_min = std::max(0, nbounce - light_path.size());
            int s_max = std::min(nbounce, camera_path.size() - 1);
            for (int s = s_min; s <= s_max; s++) {
                int t = nbounce - 1 - s;
                Spectrum path_val(0.f);
                evalBDPTPath(scene, camera_path, s, light_path, t, path_val);
                if (path_val.isZero()) continue;
                Float mis_weight = b_rec.computeWeight(scene, s, t, 0.f);
                // Non-finite contributions are dropped from the primal image.
                if((mis_weight * path_val).allFinite())
                    b_rec.image[pixel_idx] += mis_weight * path_val;
                if constexpr (ad) {
                    // Auxiliary samples for the connection edge (only when both
                    // endpoints exist).
                    if (t != -1 && s != -1)
                        sampleVertex(scene, b_rec.sampler, light_pathAD.val[t], camera_pathAD.val[s], light_pathAD.val.connect_vertex_cache);
                    light_pathAD.der.connect_vertex_cache.resize(light_pathAD.val.connect_vertex_cache.size());
                    for (auto & cache : light_pathAD.der.connect_vertex_cache) {
                        cache.setZero();
                    }

                    if (algorithm1_bdptwas::use_prefix_sum){
                        algorithm1_bdptwas::d_evalBoundary_prefix(sceneV.val, camera_pathAD, s, light_pathAD, t, 
                                                            mis_weight * path_val, b_rec.d_image[pixel_idx]);
                        if (t != -1 && s != -1){
                            algorithm1_bdptwas::d_warpSurface(light_pathAD.val.connect_vertex_cache, light_pathAD.der.connect_vertex_cache,
                                                                light_pathAD.der.connect_warped_X, light_pathAD.der.connect_div_warped_X);
                            const Intersection &connect_preV = camera_pathAD.val.vertices[s];
                            for (int j = 0; j < algorithm1_bdptwas::NUM_AUX_SAMPLES; j++) { // for every aux_point
                                const algorithm1_bdptwas::WASCache& cache = light_pathAD.val.connect_vertex_cache[j];
                                algorithm1_bdptwas::WASCache& d_cache = light_pathAD.der.connect_vertex_cache[j];
                                if (!cache.VALID) continue;
                                algorithm1_bdptwas::d_rayIntersectEdgeExt(sceneV, connect_preV, cache.aux_half, cache.aux, d_cache.aux.p);
                            }
                        }
                        // Flag the vertices touched by this strategy; the flags are
                        // cleared by resetHasGradient() after the loop.
                        for (int i = 1; i < s; i++)
                            camera_pathAD.val.vertex_container[i].hasGradient = true;
                        for (int i = 0; i < t; i++)
                            light_pathAD.val.vertex_container[i].hasGradient = true;
                    }
                    else{
                        algorithm1_bdptwas::evalWarp(sceneV.val, camera_pathAD.val, s, light_pathAD.val, t);
                        algorithm1_bdptwas::d_evalBoundary(sceneV.val, camera_pathAD, s, light_pathAD, t, 
                                                            mis_weight * path_val, b_rec.d_image[pixel_idx]);
                        algorithm1_bdptwas::d_evalWarp(sceneV.val, camera_pathAD, s, light_pathAD, t);
                        algorithm1_bdptwas::d_getWAS(sceneV, camera_pathAD, s, light_pathAD, t);
                    }

                    // if (isfinite(sceneV.gm.get(thread_id).shape_list[shape_idx]->param))
                    // {
                    //     Float delta = sceneV.gm.get(thread_id).shape_list[shape_idx]->param - param;
                    //     param = sceneV.gm.get(thread_id).shape_list[shape_idx]->param;
                    //     if (delta < 1e7 && delta > -1e7)
                    //         b_rec.grad_image[pixel_idx] += Spectrum(delta, 0.f, 0.f);
                    // }
                }
            }
        }
        if constexpr (ad) {
            // Prefix-sum mode defers the per-path derivative work until all
            // strategies have been processed.
            if (algorithm1_bdptwas::use_prefix_sum){
                algorithm1_bdptwas::d_evalPrefix(camera_pathAD);
                algorithm1_bdptwas::d_evalPrefix(light_pathAD);
                algorithm1_bdptwas::d_evalWarpPath(sceneV.val, camera_pathAD);
                algorithm1_bdptwas::d_evalWarpPath(sceneV.val, light_pathAD);
                algorithm1_bdptwas::d_getWASPath(sceneV, camera_pathAD);
                algorithm1_bdptwas::d_getWASPath(sceneV, light_pathAD);
                camera_pathAD.val.resetHasGradient();
                light_pathAD.val.resetHasGradient();
            }
            camera_pathAD.der.setZero();
            light_pathAD.der.setZero();
#ifdef FORWARD
            // Add the finite, bounded parameter change since the last read-out to
            // the first channel of the gradient image.
            if (isfinite(sceneV.gm.get(thread_id).shape_list[shape_idx]->param))
            {
                Float delta = sceneV.gm.get(thread_id).shape_list[shape_idx]->param - param;
                param = sceneV.gm.get(thread_id).shape_list[shape_idx]->param;
                if (delta < 1e7 && delta > -1e7)
                    b_rec.grad_image[pixel_idx] += Spectrum(delta, 0.f, 0.f);
            }
#endif
        }

#endif

        // Light-tracing strategies (s == -1): connect each light vertex to the camera.
        Array4 sensor_vals;
        Matrix2x4 pix_uvs;
        Vector wo;
        for (int t = 1; t < light_path.size(); t++) {
            Intersection& its = light_path[t];
            if (!scene.isVisible(its.p, true, camera.cpos, true)) continue;
            camera.sampleDirect(its.p, pix_uvs, sensor_vals, wo);
            if (sensor_vals.isZero()) continue;
            Vector wi = its.toWorld(its.wi);
            Vector wo_local = its.toLocal(wo);
            /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) continue;
            // sampleDirect() reports up to four (pixel, sensor value) pairs.
            for (int i = 0; i < 4; i++) {
                if (sensor_vals(i) < Epsilon) continue;
                Spectrum path_val = sensor_vals(i) * its.value * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection);
                if (path_val.isZero()) continue;
                int pixel_idx = camera.getPixelIndex(pix_uvs.col(i));
                light_path.pixelIdx = unravel_index(pixel_idx, {camera.width, camera.height});
#ifdef BDPT_PARTICLE
                //Particle tracer with antithetic sampling
                // NOTE(review): antithetic_success and anti_its are not declared
                // anywhere in this function — this branch needs them defined
                // before BDPT_PARTICLE can be enabled.
                Float mis_weight = 1.0f;
                if (antithetic_success) {
                    Float pdf_anti = anti_its.pdf / camera.geometric(anti_its.p, anti_its.geoFrame.n);
                    pdf_anti *= camera.geometric(light_path[t].p, light_path[t].geoFrame.n);
                    mis_weight = 1.0 / (1.0 + misRatio(pdf_anti, light_path[t].pdf));
                }
                b_rec.image[pixel_idx] += path_val * mis_weight;
#else
                // Unlike the (s, t) loop above, this accumulation has no allFinite() guard.
                Float mis_weight = b_rec.computeWeight(scene, -1, t, sensor_vals(i));
                b_rec.image[pixel_idx] += path_val * mis_weight;
#endif
                if constexpr (ad) {
                    if (!algorithm1_bdptwas::use_prefix_sum){
                        algorithm1_bdptwas::evalWarp(sceneV.val, camera_pathAD.val, -1, light_pathAD.val, t);
                    }
                    // algorithm1_bdptwas::evalBoundary(sceneV.val, camera_pathAD.val, -1, light_pathAD.val, t, mis_weight * path_val);
                    algorithm1_bdptwas::d_evalBoundary(sceneV.val, camera_pathAD, -1, light_pathAD, t, 
                                                    mis_weight * path_val, b_rec.d_image[pixel_idx]);
                    algorithm1_bdptwas::d_evalWarp(sceneV.val, camera_pathAD, -1, light_pathAD, t);
                    algorithm1_bdptwas::d_getWAS(sceneV, camera_pathAD, -1, light_pathAD, t);
                    // algorithm1_bdptwas::d_evalPath(sceneV, camera_pathAD, -1, light_pathAD, t,
                    //                             mis_weight, b_rec.d_image[pixel_idx]);
#ifdef FORWARD

                    if (isfinite(sceneV.gm.get(thread_id).shape_list[shape_idx]->param))
                    {
                        Float delta = sceneV.gm.get(thread_id).shape_list[shape_idx]->param - param;
                        param = sceneV.gm.get(thread_id).shape_list[shape_idx]->param;
                        if (delta < 1e7 && delta > -1e7)
                            b_rec.grad_image[pixel_idx] += Spectrum(delta, 0.f, 0.f);
                    }
#endif
                }
            }
        }        
        // Reset and empty both AD sub-paths (value and derivative parts) before returning.
        if constexpr (ad) {
            camera_pathAD.val.setZero();
            light_pathAD.val.setZero();
            camera_pathAD.der.setZero();
            light_pathAD.der.setZero();
            camera_pathAD.val.clear();
            light_pathAD.val.clear();
            camera_pathAD.der.clear();
            light_pathAD.der.clear();
        }

    }
}

/// Renders the primal image.
///
/// The image is split into 16x16 blocks that are distributed over OpenMP threads.
/// Every pixel gets its own sampler (seeded from options.seed and the raveled
/// pixel index) and options.num_samples bidirectional samples. Each thread
/// accumulates into its own buffer obtained from ThreadManager; the buffers are
/// merged once all blocks are done.
///
/// @param scene   scene to render
/// @param options render settings (seed, num_samples, max_bounces are read here)
/// @return the image as a flat tensor, normalised by (pixel count * num_samples)
ArrayXd BDPTWAS::renderC(const Scene &scene, const RenderOptions &options) const
{
    int size_block = scene.camera.getNumPixels();
    int num_block = options.num_samples;
    const auto &camera = scene.camera;
    const int nworker = omp_get_num_procs();
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    std::vector<Spectrum> spec_list(camera.getNumPixels(), Spectrum::Zero());
    ThreadManager thread_manager(spec_list, nworker);
    int blockProcessed = 0;     // shared progress counter, updated atomically below
    Timer _("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // The executing thread is fixed for the duration of one loop iteration.
        const int thread_id = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.num_samples; j++)
            {
                BDPTWASRecord b_rec = BDPTWASRecord(&sampler, options.max_bounces, pixelIdx, thread_manager.get(thread_id));
                bidirTrace<false>(scene, b_rec);
            }
        }

        if (verbose)
        {
            // The counter is shared by all threads; an unsynchronised ++ here is a
            // data race. Capture the incremented value atomically and report
            // progress from the private copy.
            int finished;
#pragma omp atomic capture
            finished = ++blockProcessed;
            progressIndicator(static_cast<Float>(finished) / blocks.size());
        }
    }
    if (verbose)
        std::cout << std::endl;

    thread_manager.merge();
    size_t num_samples = size_t(size_block) * num_block;
    for (auto &spec : spec_list)
        spec /= num_samples;
    return from_spectrum_list_to_tensor(spec_list, camera.getNumPixels());
}


/**
 * Constructs the integrator by writing its arguments into the
 * algorithm1_bdptwas namespace-level settings.
 *
 * Note that these settings live at namespace scope, not in the object, so the
 * most recently constructed BDPTWAS determines the values in effect.
 *
 * @param NUM_AUX    Stored in algorithm1_bdptwas::NUM_AUX_SAMPLES.
 * @param base_sigma Stored in algorithm1_bdptwas::base_sigma.
 * @param power      Stored in algorithm1_bdptwas::power.
 * @param prefix_sum Stored in algorithm1_bdptwas::use_prefix_sum.
 */
BDPTWAS::BDPTWAS(int NUM_AUX, Float base_sigma, Float power, bool prefix_sum)
{
    namespace settings = algorithm1_bdptwas;

    // Left-hand sides are namespace-qualified; the unqualified right-hand
    // sides resolve to the constructor parameters of the same name.
    settings::use_prefix_sum = prefix_sum;
    settings::power = power;
    settings::base_sigma = base_sigma;
    settings::NUM_AUX_SAMPLES = NUM_AUX;
}

/**
 * Runs the differentiated render pass.
 *
 * Uses the same 16x16 block decomposition and per-pixel sampler seeding as
 * the primal pass, but traces with bidirTrace<true>. Each record receives the
 * per-thread primal buffer, the incoming image derivative and a per-thread
 * gradient image.
 *
 * @param sceneAD Scene with derivative storage; sceneAD.gm is merged at the end.
 * @param options Provides seed, num_samples and max_bounces.
 * @param d_image Incoming derivative with respect to the rendered image.
 * @return The merged gradient image converted with from_spectrum_list_to_tensor.
 *         No normalization is applied to it here; the
 *         1 / (pixels * samples) factor is applied to d_image up front instead.
 */
ArrayXd BDPTWAS::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const
{
    const auto &camera = sceneAD.val.camera;
    const int nworker = omp_get_num_procs();
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    // Scene gradient storage, merged after the parallel loop.
    GradientManager<Scene> &gm = sceneAD.gm;

    // Primal image buffers, one per worker.
    std::vector<Spectrum> primal_image(camera.getNumPixels(), Spectrum::Zero());
    ThreadManager primal_buffers(primal_image, nworker);

    // Incoming derivative, pre-scaled by pixel count and samples per pixel.
    auto pixel_adjoints = from_tensor_to_spectrum_list(
        d_image / camera.getNumPixels() / options.num_samples, camera.getNumPixels());

    // Gradient image buffers, one per worker.
    std::vector<Spectrum> grad_image(camera.getNumPixels(), Spectrum::Zero());
    ThreadManager grad_buffers(grad_image, nworker);

    int finished_blocks = 0;
    Timer render_timer("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int block_id = 0; block_id < blocks.m_BlocksTotal; block_id++)
    {
        // A loop iteration stays on one thread, so the id is read once per block.
        const int tid = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(block_id);

        for (Array2i pix = block.curPixel(); block.hasNext(); pix = block.nextPixel())
        {
            const int flat_idx = ravel_multi_index(pix, {camera.width, camera.height});
            RndSampler sampler(options.seed, flat_idx);

            for (int s = 0; s < options.num_samples; ++s)
            {
                // A fresh record per sample, so the paths start out default-constructed.
                BDPTWASRecordAD record(&sampler, options.max_bounces, pix,
                                       primal_buffers.get(tid), pixel_adjoints, grad_buffers.get(tid));
                bidirTrace<true>(sceneAD, record);
            }
        }

        if (verbose)
        {
            // The block counter is shared by all workers; serialize the update.
#pragma omp critical
            progressIndicator(static_cast<Float>(++finished_blocks) / blocks.size());
        }
    }

    if (verbose)
        std::cout << std::endl;

    primal_buffers.merge();
    grad_buffers.merge();
    gm.merge();

    return from_spectrum_list_to_tensor(grad_image, camera.getNumPixels());
}