#include "boundary_pixel.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <render/imageblock.h>
namespace
{
    // Per-pixel query parameters for the boundary-gradient estimator (boundary_term /
    // sample_boundary_term). Callers build it with positional brace-initialisation
    // ({pixelIdx, &sampler, max_bounces, enable_antithetic, two_point_antithetic}),
    // so the field order is part of the interface.
    struct BoundaryQueryRecord
    {
        Vector2i pixel_idx;               // pixel whose boundary is being sampled
        RndSampler *sampler;              // not owned; callers pass the address of a per-pixel local sampler
        int max_bounces;                  // path depth limit forwarded to __radiance
        bool enable_antithetic = true;    // also evaluate antithetic copies of the boundary random number
        bool two_point_antithetic = true; // when antithetic: use 2 of the 4 candidate points (rnd, rnd +/- 0.5)
    };

    // Unidirectional path tracer with next-event estimation and MIS, supporting both
    // surfaces and participating media.
    //
    // Returns {radiance estimate along _ray, its_first_valid, its_first}, where its_first is
    // the first "real" vertex of the path. That is either the first medium interaction
    // (stored with pdf pdfSuccess * mRec.pdfSuccess) or the first surface with a non-null
    // BSDF, recorded while incEmission is still true (i.e. before any scattering event).
    // When its_first_valid is false, its_first is default-constructed and must not be used.
    //
    // NOTE(review): shape_id, eta and preX are written but never read; pdfFailure is updated
    // but never consumed. They look like leftovers; confirm before removing.
    [[maybe_unused]] std::tuple<Spectrum, bool, Intersection>
    __radiance(const Scene &scene, const Ray &_ray,
               const RadianceQueryRecord &rRec)
    {
        Intersection its_first;
        bool its_first_valid = false;
        // loop variables : ray, its, mRec, ret, depth, throughput, incEmission, medium
        Ray ray(_ray);
        Intersection its;
        MediumSamplingRecord mRec;
        Spectrum ret = Spectrum::Zero();
        int depth = 0;
        Spectrum throughput = Spectrum::Ones();
        // true until the first scattering event; controls direct emission pickup and its_first recording
        bool incEmission = true;
        const Medium *medium = nullptr;
        int med_id = -1; // FIXME: assume the camera is outside any shapes
        int shape_id = -1;
        med_id = rRec.med_id;
        if (rRec.med_id != -1)
            medium = scene.getMedium(rRec.med_id);
        Vector preX = scene.camera.cpos;
        Float pdfFailure = 1.; // keep track of the pdf of hitting a surface
        Float pdfSuccess = 1.; // keep track of the pdf of hitting a medium
        RndSampler *sampler = rRec.sampler;
        int max_bounces = rRec.max_bounces;
        Float eta = 1.0f;
        scene.rayIntersect(ray, false, its);
        while (depth <= max_bounces)
        {
            // Distance sampling: inside_med means a medium interaction was sampled before its.t.
            bool inside_med = medium != nullptr &&
                              medium->sampleDistance(Ray(ray), its.t, sampler, mRec);
            if (inside_med) // sampled a medium interaction
            {
                if (incEmission)
                {
                    its_first_valid = true;
                    its_first = Intersection{mRec, med_id, pdfSuccess * mRec.pdfSuccess};
                }

                if (depth >= max_bounces)
                    break;

                const PhaseFunction *phase = scene.phase_list[medium->phase_id];
                throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(mRec.p);
                Spectrum value = scene.sampleAttenuatedEmitterDirect(
                    dRec, sampler->next2D(), sampler, mRec.medium);
                if (!value.isZero())
                {
                    Float phaseVal = phase->eval(-ray.dir, dRec.dir);
                    if (phaseVal != 0)
                    {
                        Float phasePdf = phase->pdf(-ray.dir, dRec.dir);
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, phasePdf);
                        ret += throughput * value * phaseVal * mis_weight;
                    }
                }

                // ====================== phase sampling =============================
                Vector wo;
                Float phaseVal = phase->sample(-ray.dir, sampler->next2D(), wo);
                Float phasePdf = phase->pdf(-ray.dir, wo);
                if (phaseVal == 0)
                    break;
                throughput *= phaseVal;
                pdfFailure = phasePdf;
                pdfSuccess = phasePdf;
                ray = Ray(mRec.p, wo);

                value = scene.rayIntersectAndLookForEmitter(
                    ray, false, sampler, mRec.medium, its, dRec);
                if (!value.isZero())
                {
                    // NOTE(review): the emitter pdf here comes from pdfEmitterDirect(dRec), while the
                    // surface branch below uses dRec.pdf directly; confirm both are equivalent.
                    Float pdf_emitter = scene.pdfEmitterDirect(dRec);
                    Float mis_weight = miWeight(phasePdf, pdf_emitter / dRec.G);
                    ret += throughput * value * mis_weight;
                }

                // update loop variables
                incEmission = false;
                preX = mRec.p;
            }
            else // sampled a surface interaction
            {
                // Account for the probability of passing through the medium up to the surface.
                if (medium)
                {
                    pdfFailure *= mRec.pdfFailure;
                    pdfSuccess *= mRec.pdfFailure;
                    throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                if (!its.isValid())
                    break;

                // Directly visible emitters are only added before the first scattering event;
                // later hits are accounted for by the MIS-weighted NEE / BSDF sampling below.
                if (its.isEmitter() && incEmission)
                    ret += throughput * its.Le(-ray.dir);

                if (depth >= max_bounces)
                    break;

                // ====================== emitter sampling =========================
                DirectSamplingRecord dRec(its);
                if (!its.getBSDF()->isNull())
                {
                    if (incEmission)
                    {
                        its_first_valid = true;
                        its_first = its;
                    }

                    Spectrum value = scene.sampleAttenuatedEmitterDirect(
                        dRec, its, sampler->next2D(), sampler, medium);

                    if (!value.isZero())
                    {
                        Spectrum bsdfVal = its.evalBSDF(its.toLocal(dRec.dir));
                        Float bsdfPdf = its.pdfBSDF(its.toLocal(dRec.dir));
                        Float mis_weight = miWeight(dRec.pdf / dRec.G, bsdfPdf);
                        ret += throughput * value * bsdfVal * mis_weight;
                    }
                }

                // ====================== BSDF sampling =============================
                Vector wo;
                Float bsdfPdf, bsdfEta;
                Spectrum bsdfWeight = its.sampleBSDF(sampler->next3D(), wo, bsdfPdf, bsdfEta);
                if (bsdfWeight.isZero())
                    break;

                wo = its.toWorld(wo);
                ray = Ray(its.p, wo);

                throughput *= bsdfWeight;
                if (its.isMediumTransition())
                {
                    med_id = its.getTargetMediumId(wo);
                    medium = its.getTargetMedium(wo);
                }
                if (its.getBSDF()->isNull())
                {
                    // Null BSDF (e.g. a medium boundary): keep tracing without counting it as a
                    // scattering event, so incEmission is left unchanged.
                    scene.rayIntersect(ray, true, its);
                }
                else
                {
                    pdfFailure = bsdfPdf;
                    pdfSuccess = bsdfPdf;
                    Spectrum value = scene.rayIntersectAndLookForEmitter(
                        ray, true, sampler, medium, its, dRec);
                    if (!value.isZero())
                    {
                        Float mis_weight = miWeight(bsdfPdf, dRec.pdf / dRec.G);
                        ret += throughput * value * mis_weight;
                    }
                    incEmission = false;
                    preX = ray.org;
                }
            }
            depth++;
        }
        return {ret, its_first_valid, its_first};
    }

    // Scalar normal velocity of a boundary segment: a thin adapter that unpacks `seg`
    // (xD, xB and the triangle vertices xS_0..xS_2, as filled in by normalVelocities)
    // and forwards it to normal_velocity(). vB is the boundary/edge direction and n the
    // projection normal (normalVelocities passes the incident-plane normal).
    // This is the primal function that Enzyme differentiates through __normal_velocity.
    Float _normal_velocity(const PixelBoundarySegmentInfo &seg, const Vector &vB, const Vector &n)
    {
        const auto &sensorPos = seg.xD;
        const auto &boundaryPos = seg.xB;
        const auto &tri0 = seg.xS_0;
        const auto &tri1 = seg.xS_1;
        const auto &tri2 = seg.xS_2;
        return normal_velocity(sensorPos, boundaryPos, vB, tri0, tri1, tri2, n);
    }

    void __normal_velocity(const PixelBoundarySegmentInfo &seg, const Vector &vB, const Vector &n, Float &ret)
    {
        ret = _normal_velocity(seg, vB, n);
    }

    // Reverse-mode derivative of the normal velocity with respect to every field of `seg`.
    // The gradient is accumulated into d_seg; vB and n are treated as constants.
    // Without ENZYME the function does nothing and d_seg is left untouched.
    void d_normal_velocity(const PixelBoundarySegmentInfo &seg, PixelBoundarySegmentInfo &d_seg,
                           const Vector &vB, const Vector &n)
    {
        #ifdef ENZYME
        Float primal = 0.;
        Float adjoint = 1.; // seed: d(output)/d(output) = 1
        __enzyme_autodiff((void *)__normal_velocity,
                          enzyme_dup, &seg, &d_seg,
                          enzyme_const, &vB,
                          enzyme_const, &n,
                          enzyme_dup, &primal, &adjoint);
        #endif
    }

    // Builds the boundary segment for bRec and returns the gradient of its normal velocity
    // w.r.t. the segment parameters (via d_normal_velocity), flattened to 9 coefficients
    // by PixelBoundarySegmentInfo::getVelocities().
    // Segment layout:
    //  - xD: camera position
    //  - xB: the boundary direction pushed to the z = 1 plane in camera space, mapped to world
    //  - xS_0..xS_2: vertices of triangle (shape_id, triangle_id)
    Eigen::Matrix<Float, 9, 1> normalVelocities(const Scene &scene, PixelBoundarySamplingRecord &bRec)
    {
        const auto &cam = scene.camera;

        // Incident-plane normal, oriented toward the visible side of the edge.
        const Vector edgeWorld = detach(cam.toWorld(bRec.dir_edge));
        const Vector rayWorld = cam.toWorld(bRec.dir_local);
        const Vector visibleWorld = detach(cam.toWorld(bRec.dir_visible));
        Vector nB = detach(edgeWorld.cross(rayWorld.normalized()));
        nB *= math::signum(nB.dot(visibleWorld));

        // Rescale the camera-space direction so that z == 1.
        Vector onPlane = bRec.dir_local;
        const Float zScale = 1.0 / onPlane.z();
        onPlane *= zScale;

        const Shape *shapeS = scene.shape_list[bRec.shape_id];
        const auto &tri = shapeS->getIndices(bRec.triangle_id);

        PixelBoundarySegmentInfo segInfo;
        segInfo.xS_0 = shapeS->getVertex(tri[0]);
        segInfo.xS_1 = shapeS->getVertex(tri[1]);
        segInfo.xS_2 = shapeS->getVertex(tri[2]);
        segInfo.xB = cam.camToWorld(onPlane);
        segInfo.xD = cam.cpos;

        PixelBoundarySegmentInfo gradSeg;
        gradSeg.setZero();
        d_normal_velocity(segInfo, gradSeg, edgeWorld, nB);
        return gradSeg.getVelocities();
    }

    // Largest (signed) coefficient of the 9 normal-velocity derivatives for bRec.
    Float max_normal_velocity(const Scene &scene, PixelBoundarySamplingRecord &bRec)
    {
        const auto perParam = normalVelocities(scene, bRec);
        return perParam.maxCoeff();
    }

    double tetra_vol(const Vector3d &a, const Vector3d &b, const Vector3d &c, const Vector3d &d)
    {
        return std::abs((a - d).dot((b - d).cross(c - d))) / 6.;
    }

    // Barycentric coordinates of p in tetrahedron (a, b, c, d), computed as ratios of
    // sub-tetrahedron volumes. Each weight uses the sub-volume opposite its vertex.
    // Because tetra_vol is unsigned, the result is only meaningful for p inside (or on) the
    // tetrahedron. A degenerate tetrahedron divides by zero.
    Vector4 getBarycentric(const Vector &p,
                           const Vector3d &a, const Vector3d &b, const Vector3d &c, const Vector3d &d)
    {
        const Vector4 subVolumes(tetra_vol(b, c, d, p),
                                 tetra_vol(a, c, d, p),
                                 tetra_vol(a, b, d, p),
                                 tetra_vol(a, b, c, p));
        return subVolumes / tetra_vol(a, b, c, d);
    }
    // Boundary velocity term for the first vertex xS seen through the sampled boundary ray.
    // Its value is always zero (x - detach(x)) but its derivative is non-zero. It is
    // differentiated by Enzyme through __velocity / d_velocity.
    //  - surface case: xS is re-intersected with triangle (shape_id, triangle_id) and projected
    //    onto n, the in-surface direction orthogonal to the plane/surface intersection line u.
    //  - medium case: xS is re-expressed with the barycentric weights of tetrahedron tet_id_S
    //    (vertices detached) and projected onto the incident-plane normal nB.
    Float velocity(const Scene &scene, PixelBoundarySamplingRecord &bRec)
    {
        Vector org = scene.camera.cpos;
        Vector xD = org;
        Vector dir_world = scene.camera.toWorld(bRec.dir_local);
        Vector dir_visible = detach(scene.camera.toWorld(bRec.dir_visible));
        Vector dir_edge = detach(scene.camera.toWorld(bRec.dir_edge));
        Ray ray(org, dir_world.normalized());
        Vector nB = detach(dir_edge.cross(ray.dir)).normalized(); // normal of incident plane
        nB *= math::signum(nB.dot(dir_visible));
        if (bRec.onSurface_S)
        {
            const Shape *shape = scene.getShape(bRec.shape_id);
            const Vector3i &index = shape->getIndices(bRec.triangle_id);
            const Vector &x1 = shape->getVertex(index[0]),
                         &x2 = shape->getVertex(index[1]),
                         &x3 = shape->getVertex(index[2]);
            Vector x = rayIntersectTriangle2(x1, x2, x3, ray, EReference);
            // Deliberately shadows the outer nB with an unnormalised copy. Its scale does not
            // matter because n is normalised below.
            Vector nB = detach(dir_edge.cross(ray.dir));               // normal of incident plane
            nB *= math::signum(nB.dot(dir_visible));
            Vector nS = detach((x2 - x1).cross(x3 - x1).normalized()); // normal of the intersection
            Vector u = detach(nB.cross(nS));                           // direction of the plane-plane intersection
            Vector n = detach(u.cross(nS));
            n *= math::signum(n.dot(nB));
            n = detach(n.normalized());
            Float x_proj = x.dot(detach(n));
            return x_proj - detach(x_proj);
        }
        else
        {
            assert(bRec.med_id_S != -1);
            assert(bRec.tet_id_S != -1);
            // std::abs, not ::abs: the unqualified C `abs(int)` would truncate the double
            // difference to 0 and make this check always pass.
            assert(std::abs(1.0 - bRec.barycentric4_S.sum()) < 1e-5);
            const Medium *medium = scene.medium_list[bRec.med_id_S];
            const Vector4i &indices = medium->getTet(bRec.tet_id_S);
            const Vector &xS_0 = medium->getVertex(scene, indices[0]);
            const Vector &xS_1 = medium->getVertex(scene, indices[1]);
            const Vector &xS_2 = medium->getVertex(scene, indices[2]);
            const Vector &xS_3 = medium->getVertex(scene, indices[3]);
            // getPoint
            Vector xS = Vector::Zero();
            for (int i = 0; i < 4; i++)
                xS += bRec.barycentric4_S[i] *
                      medium->getVertex(scene, indices[i]);
            Float dist = detach((xS - xD).norm());
            // trace xD and project xD back to reference
            xS = ray(dist);
            Vector4 bary = getBarycentric(xS, xS_0, xS_1, xS_2, xS_3);
            assert(bary.cwiseAbs().sum() > 0.99 && bary.cwiseAbs().sum() < 1.01);
            xS = detach(xS_0) * bary[0] +
                 detach(xS_1) * bary[1] +
                 detach(xS_2) * bary[2] +
                 detach(xS_3) * bary[3];
            Float xS_proj = xS.dot(detach(nB));
            return xS_proj - detach(xS_proj); // val = 0, der != 0
        }
    }

    void __velocity(const Scene &scene, PixelBoundarySamplingRecord &bRec, Float &res)
    {
        res = -velocity(scene, bRec);
    }

    // Back-propagates the adjoint d_res through __velocity. Scene derivatives are accumulated
    // into the calling OpenMP thread's slot of the gradient manager (sceneAD.gm.get(tid)).
    // Without ENZYME this is a no-op.
    void d_velocity(SceneAD &sceneAD, PixelBoundarySamplingRecord &pRec, Float d_res)
    {
#if defined(ENZYME)
        Float res; // primal output slot written by __velocity
        const int tid = omp_get_thread_num();
        __enzyme_autodiff((void *)__velocity,
                          enzyme_dup, &sceneAD.val, &sceneAD.gm.get(tid),
                          enzyme_const, &pRec,
                          enzyme_dup, &res, &d_res);
#endif
    }

    // The 9 normal-velocity derivatives of one boundary segment (see normalVelocities).
    using Vector9d = Eigen::Matrix<Float, 9, 1>;
    // Per-triangle accumulation of Vector9d contributions, keyed by triangle id.
    using PixelMap = std::unordered_map<int, Vector9d>;

    // One boundary-sample contribution to the guiding density of pixel rRec.pixel_idx.
    // Traces the boundary ray for `rnd`. If a first vertex is found, returns
    // {true, triangle id of that vertex, normal-velocity derivatives weighted by the sample's
    // radiance, sensor response, Jacobian/geometry term and 1/pdf}. Otherwise returns
    // {false, -1, 0}.
    std::tuple<bool, int, Vector9d> boundaryContribution(const Scene &scene, RadianceQueryRecord &rRec, Float rnd)
    {
        auto [ray, dir_visible, dir_local, dir_edge] = scene.camera.samplePrimaryBoundaryRay(rRec.pixel_idx, rnd);
        auto [value, valid_first, its_first] = __radiance(scene, ray, rRec);
        // Check validity before touching its_first: when no first vertex was recorded it is
        // default-constructed, and its position must not be fed to sampleDirect.
        if (!valid_first)
            return {false, -1, Vector9d::Zero()};
        int pixel_idx;
        Vector dir;
        Spectrum value_sensor = Spectrum::Ones() * scene.camera.sampleDirect(its_first.p, 0., pixel_idx, dir);
        bool onSurface = its_first.type == EVSurface;
        PixelBoundarySamplingRecord pRec{onSurface, dir_local, dir_visible, dir_edge};
        assert(onSurface);
        if (onSurface)
        {
            pRec.shape_id = its_first.shape_id;
            pRec.triangle_id = its_first.triangle_id;
        }
        else
        {
            const Medium *medium = scene.medium_list[its_first.medium_id];
            int tet_id = medium->m_tetmesh.in_element(its_first.p);
            if (tet_id == -1)
                return {false, -1, Vector9d::Zero()};
            Vector4 barycentric4 = medium->m_tetmesh.getBarycentric(tet_id, its_first.p);
            pRec.med_id_S = its_first.medium_id;
            pRec.tet_id_S = tet_id;
            pRec.barycentric4_S = barycentric4;
        }
        auto velocities = normalVelocities(scene, pRec);

        // Move the camera-space boundary direction onto the z == 1 plane.
        Float t = 1.0 / dir_local.z();
        dir_local *= t;

        Float pdf = 0.25 / (2 * tan(scene.camera.m_fov * 0.5 * M_PI / 180.0) / scene.camera.height);

        Float baseValue = 0;
        Vector xD = scene.camera.cpos;

        Vector xB = scene.camera.camToWorld(dir_local); // boundary point on the z == 1 image plane
        Vector vB = (scene.camera.toWorld(dir_edge)).normalized();
        Vector xS = its_first.p;
        Float xD_xB = (xB - xD).norm();
        Float xD_xS = (xS - xD).norm();
        Float sB = sinB(xD, xB, vB, xS);
        Float sD = sinD(xD, xB, vB, xS, its_first.geoFrame.n);

        if (sB > Epsilon && sD > Epsilon)
        {
            baseValue = xD_xS / xD_xB * sB / sD;
        }
        Float cosS = std::abs(its_first.geoFrame.n.dot(-ray.dir));
        baseValue *= cosS;
        assert(!std::isnan(baseValue));
        assert(value.allFinite());

        return {true, pRec.triangle_id, velocities * (value_sensor * value).sum() * baseValue / pdf};
    }

    // Guiding density of one pixel. Draws pRec.nsamples boundary samples, plus three
    // antithetic companions per sample when enabled (replayed from the same sampler state).
    // Their normal-velocity contributions are accumulated per triangle.
    // NOTE(review): the returned density is component 0 of whichever triangle
    // unordered_map::begin() yields (hash order), divided by nsamples. An earlier, disabled
    // variant took the max |coefficient| over all triangles; confirm which is intended.
    Float guidingDensity(const Scene &scene, PixelQueryRecord &pRec)
    {
        PixelMap perTriangle;
        // Adds v into triangle tri's slot, creating the slot on first use (Eigen fixed-size
        // vectors are not zero-initialised, so operator[] followed by += would read garbage).
        const auto accumulate = [&perTriangle](int tri, const Vector9d &v) {
            auto slot = perTriangle.find(tri);
            if (slot == perTriangle.end())
                perTriangle.emplace(tri, v);
            else
                slot->second += v;
        };

        for (int s = 0; s < pRec.nsamples; ++s)
        {
            const Float rnd = pRec.sampler->next1D();
            pRec.sampler->save();
            {
                auto [ok, tri, vel] = boundaryContribution(scene, pRec, rnd);
                if (ok)
                    accumulate(tri, vel);
            }
            if (!pRec.enable_antithetic)
                continue;

            // Antithetic companions: rnd shifted by half a period and two mirrored variants.
            const Float mirrored[] = {
                rnd < 0.5 ? rnd + 0.5 : rnd - 0.5,
                rnd < 0.75 ? 0.75 - rnd : 1.75 - rnd,
                rnd < 0.25 ? 0.25 - rnd : 1.25 - rnd};
            for (const Float u : mirrored)
            {
                pRec.sampler->restore();
                auto [ok, tri, vel] = boundaryContribution(scene, pRec, u);
                if (ok)
                    accumulate(tri, vel);
            }
        }

        const Float density = perTriangle.empty() ? Float(0) : perTriangle.begin()->second[0];
        assert(!std::isnan(density));
        return density / pRec.nsamples;
    }

    // density map
    std::vector<Float> guidingDensity(const Scene &scene, RndSampler *sampler, const RenderOptions &options)
    {
        const auto &camera = scene.camera;
        int height = camera.height;
        int width = camera.width;
        std::vector<Float> res(height * width);
        for (int i = 0; i < res.size(); i++)
        {
            PixelQueryRecord pRec{unravel_index(i, {camera.width, camera.height}), sampler, options.max_bounces, 100, true};
            res[i] = guidingDensity(scene, pRec);
        }
        return res;
    }

    // Radiance carried by one randomly sampled boundary ray of pixel rRec.pixel_idx.
    // Returns zero when the ray does not hit the reference geometry.
    // The sampler is consumed in the same order as before: boundary offset first, then path tracing.
    Spectrum pixelColor(const Scene &scene, RadianceQueryRecord &rRec)
    {
        auto [ray, dir_visible, dir_local, dir_edge] = scene.camera.samplePrimaryBoundaryRay(rRec.pixel_idx, rRec.sampler->next1D());
        const Spectrum value = std::get<0>(__radiance(scene, ray, rRec));
        Intersection its;
        if (!scene.rayIntersect(ray, false, its, EReference))
            return Spectrum::Zero();
        // The sensor response, pdf and hit distance that used to be computed here never
        // affected the result and have been dropped.
        return value;
    }

    // One boundary sample of the pixel-boundary gradient. Traces the boundary ray for `rnd`,
    // builds the sampling record for its first vertex (a surface triangle or a medium
    // tetrahedron), and back-propagates value * sensor * d_res * (Jacobian * geometry) / pdf
    // into the scene gradient through d_velocity.
    __attribute__((optnone)) void sample_boundary_term(SceneAD &sceneAD, Float rnd, BoundaryQueryRecord &bRec, Spectrum d_res)
    {
        const Scene &scene = sceneAD.val;
        Float pdf = 0.25 / (2 * tan(scene.camera.m_fov * 0.5 * M_PI / 180.0) / scene.camera.height);
        auto [ray, dir_visible, dir_local, dir_edge] = scene.camera.samplePrimaryBoundaryRay(bRec.pixel_idx, rnd);
        RadianceQueryRecord rRec(bRec.pixel_idx, bRec.sampler, bRec.max_bounces);
        auto [value, is_first_valid, its_first] = __radiance(scene, ray, rRec);
        if (value.isZero())
            return;
        // its_first is default-constructed when no first vertex was recorded; bail out before
        // its position is handed to sampleDirect.
        if (!is_first_valid)
            return;
        int pixel_idx;
        Vector dir;
        Spectrum value_sensor = Spectrum::Ones() * scene.camera.sampleDirect(its_first.p, 0., pixel_idx, dir);
        bool onSurface = its_first.type & EVSurface;
        PixelBoundarySamplingRecord pRec;
        pRec.dir_local = dir_local;
        pRec.dir_visible = dir_visible;
        pRec.dir_edge = dir_edge;
        if (onSurface)
        {
            pRec.onSurface_S = true;
            // for surface vertex
            pRec.shape_id = its_first.shape_id;
            pRec.triangle_id = its_first.triangle_id;
        }
        else
        {
            pRec.onSurface_S = false;
            const Medium *medium = scene.medium_list[its_first.medium_id];
            int tet_id = medium->m_tetmesh.in_element(its_first.p);
            if (tet_id == -1)
                return;
            Vector4 barycentric4 = medium->m_tetmesh.getBarycentric(tet_id, its_first.p);
            pRec.med_id_S = its_first.medium_id;
            pRec.tet_id_S = tet_id;
            pRec.barycentric4_S = barycentric4;
        }
        // Move the camera-space boundary direction onto the z == 1 plane.
        Float t = 1.0 / dir_local.z();
        dir_local *= t;

        Vector xD = scene.camera.cpos;
        Vector xB = scene.camera.camToWorld(dir_local);
        Vector vB = (scene.camera.toWorld(dir_edge)).normalized();
        Vector xS = its_first.p;
        Float xD_xB = (xB - xD).norm();
        Float xD_xS = (xS - xD).norm();
        Float baseValue = 0; // J * G
        if (onSurface)
        {
            Float sB = sinB(xD, xB, vB, xS);
            Float sD = sinD(xD, xB, vB, xS, its_first.geoFrame.n);

            if (sB > Epsilon && sD > Epsilon)
            {
                baseValue = xD_xS / xD_xB * sB / sD;
            }
            Float cosS = std::abs(its_first.geoFrame.n.dot(-ray.dir));
            baseValue *= cosS;
        }
        else
        {
            Float sB = sinB(xD, xB, vB, xS);
            if (sB > Epsilon)
            {
                baseValue = xD_xS / xD_xB * sB;
            }
        }
        d_velocity(sceneAD, pRec, (value * value_sensor * d_res * baseValue / pdf).sum());
    }

    // Draws one boundary random number and evaluates sample_boundary_term for it. With
    // antithetic sampling enabled, it also evaluates companion offsets that replay the same
    // sampler state, and splits d_res evenly between the evaluated points.
    // Two-point mode uses {rnd, rnd +/- 0.5}; otherwise all four candidates are used.
    __attribute__((optnone)) void boundary_term(SceneAD &sceneAD, BoundaryQueryRecord &bRec, Spectrum d_res)
    {
        const Float rnd = bRec.sampler->next1D();
        if (!bRec.enable_antithetic)
        {
            sample_boundary_term(sceneAD, rnd, bRec, d_res);
            return;
        }

        const Float offsets[] = {
            rnd,
            rnd < 0.5 ? rnd + 0.5 : rnd - 0.5,
            rnd < 0.75 ? 0.75 - rnd : 1.75 - rnd,
            rnd < 0.25 ? 0.25 - rnd : 1.25 - rnd};

        bRec.sampler->save();
        if (bRec.two_point_antithetic)
        {
            sample_boundary_term(sceneAD, offsets[0], bRec, d_res / 2.0);
            bRec.sampler->restore();
            sample_boundary_term(sceneAD, offsets[1], bRec, d_res / 2.0);
            return;
        }

        for (const Float u : offsets)
        {
            bRec.sampler->restore();
            sample_boundary_term(sceneAD, u, bRec, d_res / 4.0);
        }
    }
}

// Gradient of the scalar normal velocity with respect to every field of `seg`
// (vB and n held constant). Returns an all-zero record when built without ENZYME.
PixelBoundarySegmentInfo PixelBoundaryIntegrator::normalVelocity(const PixelBoundarySegmentInfo &seg, const Vector &vB, const Vector &n) const
{
    PixelBoundarySegmentInfo gradient;
    gradient.setZero();
    // Same Enzyme call as before, shared with normalVelocities through the file-local helper.
    d_normal_velocity(seg, gradient, vB, n);
    return gradient;
}

// Guiding density for a single pixel, using the same deterministic per-pixel sampler
// seeding (options.seed, flat pixel index) and sppe0 sample count as guide().
Float PixelBoundaryIntegrator::guidingDensity(const Scene &scene, RenderOptions options, const Array2i &pixelIdx) const
{
    const auto &cam = scene.camera;
    const int flatIdx = ravel_multi_index(pixelIdx, {cam.width, cam.height});
    RndSampler pixelSampler(options.seed, flatIdx);
    PixelQueryRecord query{pixelIdx, &pixelSampler, options.max_bounces, options.sppe0, true};
    return ::guidingDensity(scene, query);
}

// Builds the guiding map in parallel over 16x16 image blocks, one guiding density per
// pixel. It blends it with a uniform map as (1 - weight) + weight * density, stores the
// result in guideMap, and returns it.
ArrayXd PixelBoundaryIntegrator::guide(const Scene &scene, RenderOptions options, Float weight) const
{
    const auto &cam = scene.camera;
    std::vector<Float> density(cam.height * cam.width);
    BlockedImage blocks({cam.width, cam.height}, {16, 16});
    const int workers = omp_get_num_procs();
    int finishedBlocks = 0;
#pragma omp parallel for num_threads(workers) schedule(dynamic, 1)
    for (int b = 0; b < blocks.m_BlocksTotal; b++)
    {
        ImageBlock block = blocks.getBlock(b);
        for (Array2i px = block.curPixel(); block.hasNext(); px = block.nextPixel())
        {
            const int flat = ravel_multi_index(px, {cam.width, cam.height});
            RndSampler pixelSampler(options.seed, flat);
            PixelQueryRecord query{px, &pixelSampler, options.max_bounces, options.sppe0, true};
            density[flat] = ::guidingDensity(scene, query);
        }
#pragma omp critical
        progressIndicator(static_cast<Float>(++finishedBlocks) / blocks.size());
    }
    guideMap = Eigen::Map<ArrayXd, Eigen::Unaligned>(density.data(), density.size());
    assert(weight > 0 && weight <= 1);
    guideMap = 1 - weight + weight * guideMap;
    return guideMap;
}

// Converts a guiding map into expected per-pixel sample counts. The map is normalised to
// sum to 1 and scaled by sppe0 * (number of pixels), so a uniform map yields sppe0 everywhere.
ArrayXd PixelBoundaryIntegrator::getSampleMap(const ArrayXd &_guideMap, int sppe0) const
{
    assert(_guideMap.size() > 0);
    const ArrayXd probability = _guideMap / _guideMap.sum();
    const int numPixels = probability.size();
    return (static_cast<double>(sppe0) * probability) * static_cast<double>(numPixels);
}

// Per-pixel variance estimate of the boundary gradient. Runs options.sppe0 boundary samples
// per pixel (the adjoint image is pre-scaled by 1/sppe0) and, under FORWARD, records the
// empirical variance of the per-sample parameter derivative.
// Without FORWARD, varMap stays all zeros, but gradients are still accumulated into
// sceneAD (gm.merge()).
// NOTE(review): with sppe0 == 0, d_image / sppe0 and the variance formula divide by zero.
ArrayXd PixelBoundaryIntegrator::getVarMap(SceneAD &sceneAD, RenderOptions options, const ArrayXd &d_image) const
{
    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    Scene &d_scene = sceneAD.der;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    // Each of the sppe0 samples carries 1/sppe0 of the pixel adjoint.
    std::vector<Spectrum> d_image_spec_list = from_tensor_to_spectrum_list(
        d_image / options.sppe0, camera.getNumPixels());

    // gradient image
    std::vector<Spectrum> g_image_spec_list(camera.getNumPixels(), Spectrum::Zero());
#if defined(FORWARD)
    int shape_idx = scene.getShapeRequiresGrad();
#endif

    int blockProcessed = 0;
    ArrayXd varMap = ArrayXd::Zero(camera.getNumPixels());
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        const int tid = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            // adaptive sampling
            // Running sums of the per-sample derivative (ex) and its square (ex2).
            Float ex2 = 0.;
            Float ex = 0.;
            for (int j = 0; j < options.sppe0; j++)
            {
#ifdef FORWARD
                // Reset this thread's accumulator so the value after boundary_term is this sample's derivative.
                gm.get(tid).shape_list[shape_idx]->param = 0;
                Float param = gm.get(tid).shape_list[shape_idx]->param;
#endif
                BoundaryQueryRecord bRec{pixelIdx, &sampler, options.max_bounces, enable_antithetic, two_point_antithetic};
                boundary_term(sceneAD, bRec, d_image_spec_list[pixel_ravel_idx]);
#ifdef FORWARD
                if (isfinite(gm.get(tid).shape_list[shape_idx]->param))
                {
                    param = gm.get(tid).shape_list[shape_idx]->param - param;
                    g_image_spec_list[pixel_ravel_idx] += Spectrum(param, 0.f, 0.f);
                }
                ex2 += param * param;
                ex += param;
#endif
            }
#ifdef FORWARD
            // Var = E[x^2] - E[x]^2 over the sppe0 samples.
            Float var = ex2 / options.sppe0 - ex * ex / (options.sppe0 * options.sppe0);
            varMap[pixel_ravel_idx] = var;
#endif
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;

    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return varMap;
}

// Debug helper: exposes the intermediate directions and normals used by the surface case
// of the boundary velocity term (see velocity()), keyed by name. Only the surface case is
// supported (asserted).
Properties PixelBoundaryIntegrator::velocities(const Scene &scene, const PixelBoundarySamplingRecord &bRec) const
{
    Vector org = scene.camera.cpos;
    Vector dir_world = scene.camera.toWorld(bRec.dir_local);
    Vector dir_visible = detach(scene.camera.toWorld(bRec.dir_visible));
    Vector dir_edge = detach(scene.camera.toWorld(bRec.dir_edge));
    Ray ray(org, dir_world.normalized());
    Vector nB = detach(dir_edge.cross(ray.dir)).normalized(); // normal of incident plane
    nB *= math::signum(nB.dot(dir_visible));
    assert(bRec.onSurface_S);
    const Shape *shape = scene.getShape(bRec.shape_id);
    const Vector3i &index = shape->getIndices(bRec.triangle_id);
    const Vector &x1 = shape->getVertex(index[0]),
                 &x2 = shape->getVertex(index[1]),
                 &x3 = shape->getVertex(index[2]);
    // The ray/triangle intersection point and its projection were computed here but never
    // reported; they have been removed together with the unused xD.
    Vector nB1 = detach(dir_edge.cross(ray.dir));              // unnormalised incident-plane normal (as in velocity())
    nB1 *= math::signum(nB1.dot(dir_visible));
    Vector nS = detach((x2 - x1).cross(x3 - x1).normalized()); // normal of the intersection
    Vector u = detach(nB1.cross(nS));                          // direction of the plane-plane intersection
    Vector n = detach(u.cross(nS));
    n *= math::signum(n.dot(nB1));
    n = detach(n.normalized());
    return Properties({{"bRec.dir_local", bRec.dir_local},
                       {"bRec.dir_visible", bRec.dir_visible},
                       {"bRec.dir_edge", bRec.dir_edge},
                       {"dir_world", dir_world},
                       {"dir_visible", dir_visible},
                       {"dir_edge", dir_edge},
                       {"nB", nB},
                       {"nB1", nB1},
                       {"nS", nS},
                       {"u", u},
                       {"n", n},
                       {"u.cross(nS)", u.cross(nS)}});
}

// Radiance along `ray` estimated by the integrator's path tracer. The first-vertex
// information that __radiance also produces is discarded.
Spectrum PixelBoundaryIntegrator::Li(const Scene &scene, const Ray &ray, const RadianceQueryRecord &rRec) const
{
    return std::get<0>(__radiance(scene, ray, rRec));
}

// Forward pass: per-pixel average of options.sppe0 boundary-ray radiance samples
// (pixelColor). Pixels are processed in parallel over 16x16 blocks, each with a
// deterministic sampler seeded by (options.seed, flat pixel index).
// With sppe0 == 0 the image is all zeros.
ArrayXd PixelBoundaryIntegrator::renderC(const Scene &scene, const RenderOptions &options) const
{
    std::cout << "PixelBoundaryIntegrator renderC" << std::endl;
    Timer _("Forward rendering");
    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    BlockedImage blocks({camera.width, camera.height}, {16, 16});
    std::vector<Spectrum> spec_list(camera.getNumPixels(), Spectrum::Zero());
    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.sppe0; j++)
            {
                RadianceQueryRecord rRec{pixelIdx, &sampler, options.max_bounces};
                spec_list[pixel_ravel_idx] += pixelColor(scene, rRec);
            }
            // sppe0 == 0 is a valid setting (renderD asserts for it explicitly); dividing
            // the zero accumulator by it would turn every pixel into 0/0 = NaN.
            if (options.sppe0 > 0)
                spec_list[pixel_ravel_idx] /= options.sppe0;
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;
    return from_spectrum_list_to_tensor(spec_list, camera.getNumPixels());
}

// Backward pass for the pixel-boundary term: back-propagates d_image through boundary
// samples into sceneAD (merged per-thread gradients, then d_scene.configureD).
// When a guiding map is present, each pixel's sample budget is
// sppe0 * guideMap[i] * numPixels + 1 (truncated); otherwise it is sppe0. The per-pixel
// adjoint is split evenly across that budget. Under FORWARD, the returned image holds the
// per-pixel derivative of the selected parameter; otherwise it is all zeros.
ArrayXd PixelBoundaryIntegrator::renderD(SceneAD &sceneAD,
                                         const RenderOptions &options, const ArrayXd &d_image) const
{
    assert(options.sppe0 == 0 || dynamic_cast<BoxFilter *>(sceneAD.val.camera.rfilter) != nullptr);
    // guiding info
    bool enable_guiding = (!(guideMap.size() == 0));
    if (!enable_guiding)
        std::cout << "PixelBoundaryIntegrator renderD without guiding" << std::endl;
    else
    {
        std::cout << "PixelBoundaryIntegrator renderD with guiding" << std::endl;
        // FIXME : normalize guiding map. remember to handle epsilon
        guideMap /= guideMap.sum();
    }

    const Scene &scene = sceneAD.val;
    GradientManager<Scene> &gm = sceneAD.gm;
    Scene &d_scene = sceneAD.der;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    Timer _("Render interior");

    const auto &camera = scene.camera;
    assert(!camera.rect.isValid());
    BlockedImage blocks({camera.width, camera.height}, {16, 16});

    std::vector<Spectrum> d_image_spec_list = from_tensor_to_spectrum_list(
        d_image, camera.getNumPixels());

    // adaptive sampling: per-pixel sample budget and matching adjoint share
    std::vector<int> ad_sample_count(camera.getNumPixels(), 0);
    for (int i = 0; i < ad_sample_count.size(); i++)
    {
        ad_sample_count[i] = options.sppe0;
        if (enable_guiding)
        {
            Float guiding_density = guideMap[i];
            ad_sample_count[i] = ad_sample_count[i] * guiding_density * guideMap.size() + 1;
        }
        // Pixels without samples (e.g. sppe0 == 0 and no guiding) never use their adjoint;
        // skip the division instead of filling it with inf/NaN.
        if (ad_sample_count[i] > 0)
            d_image_spec_list[i] /= ad_sample_count[i];
    }

    // gradient image
    std::vector<Spectrum> g_image_spec_list(camera.getNumPixels(), Spectrum::Zero());
#if defined(FORWARD)
    int shape_idx = scene.getShapeRequiresGrad();
#endif

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        const int tid = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
#ifdef FORWARD
            gm.get(tid).shape_list[shape_idx]->param = 0;
            Float param = gm.get(tid).shape_list[shape_idx]->param;
#endif
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            // adaptive sampling
            for (int j = 0; j < ad_sample_count[pixel_ravel_idx]; j++)
            {
                BoundaryQueryRecord bRec{pixelIdx, &sampler, options.max_bounces, enable_antithetic, two_point_antithetic};
                boundary_term(sceneAD, bRec, d_image_spec_list[pixel_ravel_idx]);
            }
#ifdef FORWARD
            if (isfinite(gm.get(tid).shape_list[shape_idx]->param))
            {
                param = gm.get(tid).shape_list[shape_idx]->param - param;
                g_image_spec_list[pixel_ravel_idx] += Spectrum(param, 0.f, 0.f);
            }
#endif
        }
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
    }
    if (verbose)
        std::cout << std::endl;

    gm.merge();
    d_scene.configureD(scene);
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer _2("normal preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return from_spectrum_list_to_tensor(g_image_spec_list, camera.getNumPixels());
}