#include "ptracer2.h"
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"
#include <core/statistics.h>

#define USE_NEW_CAMERA_SAMPLING

namespace
{
    // State of one particle while traceParticle<false>() follows it through the
    // scene (primal pass). renderC brace-initialises this struct positionally,
    // so the member order below must stay in sync with that initialiser list.
    struct ParticleRecord
    {
        // Sampler that supplies the random numbers drawn while tracing. Not
        // owned: renderC passes the address of a sampler living on its stack.
        mutable RndSampler *sampler;
        // traceParticle stops looping once `bounces` reaches this value.
        int max_bounces;
        // Number of loop iterations traceParticle has completed so far.
        int bounces = 0;
        // Id of the medium the particle currently travels in; traceParticle
        // sets it from the emitter intersection and updates it on medium
        // transitions.
        int med_id = -1;
        // Medium matching `med_id`; nullptr is treated as "not inside a medium".
        const Medium *medium = nullptr;
        // Throughput accumulated along the path by traceParticle.
        Spectrum throughput;
        // Buffer the handle* functions add pixel contributions to (renderC
        // passes thread_manager.get(thread_id)).
        std::vector<Spectrum> &image;
        // Copied from the integrator by renderC. In the primal handlers these
        // two flags are referenced only by commented-out code.
        bool enable_antithetic = true;
        bool is_equal_trans = false;
        // Light path built vertex by vertex in traceParticle. `mutable` so that
        // code holding a const reference to the record can still modify it.
        mutable algorithm1_ptracer::LightPath path;
        // Returns the (mutable) light path even through a const record.
        algorithm1_ptracer::LightPath &getPath() const
        {
            return path;
        }
    };

    // State of one particle while traceParticle<true>() follows it through the
    // scene (differentiable pass). renderD brace-initialises this struct
    // positionally, so the member order below must stay in sync with that
    // initialiser list.
    struct ParticleRecordAD
    {
        // Sampler that supplies the random numbers drawn while tracing. Not
        // owned: renderD passes the address of a sampler living on its stack.
        mutable RndSampler *sampler;
        // traceParticle stops looping once `bounces` reaches this value.
        int max_bounces;
        // Number of loop iterations traceParticle has completed so far.
        int bounces = 0;
        // Id of the medium the particle currently travels in; traceParticle
        // sets it from the emitter intersection and updates it on medium
        // transitions.
        int med_id = -1;
        // Medium matching `med_id`; nullptr is treated as "not inside a medium".
        const Medium *medium = nullptr;
        // Throughput accumulated along the path by traceParticle.
        Spectrum throughput;
        // Primal image buffer supplied by renderD; none of the *AD handlers in
        // this file reads or writes it.
        std::vector<Spectrum> &image;
        // Per-pixel input spectra; the *AD handlers pass d_image[pixel_idx] to
        // algorithm1_ptracer::d_eval.
        std::vector<Spectrum> &d_image;
        // Per-pixel output written only when FORWARD is defined (change of the
        // tracked shape parameter, stored in the first channel).
        std::vector<Spectrum> &grad_image;
        // When true, the *AD handlers also evaluate the antithetic path returned
        // by algorithm1_ptracer::antithetic* (if that call reports it as valid).
        bool enable_antithetic = true;
        // Forwarded to algorithm1_ptracer::antithetic_vol by the medium handler.
        bool is_equal_trans = false;
        // Light path built vertex by vertex in traceParticle and temporarily
        // extended with the camera vertex by the *AD handlers. `mutable` so that
        // handlers taking a const record can still modify it.
        mutable algorithm1_ptracer::LightPath path;
        // Returns the (mutable) light path even through a const record.
        algorithm1_ptracer::LightPath &getPath() const
        {
            return path;
        }
    };

    template <typename T, bool ad>
    struct type_traits;

    template <>
    struct type_traits<Scene, false>
    {
        using type = const Scene;
    };
    template <>
    struct type_traits<Scene, true>
    {
        using type = SceneAD;
    };

    template <typename T>
    auto &value(T &t);

    template <>
    auto &value(const Scene &t)
    {
        return t;
    }

    template <>
    auto &value(SceneAD &t)
    {
        return t.val;
    }

    template <>
    struct type_traits<ParticleRecord, false>
    {
        using type = ParticleRecord;
    };
    template <>
    struct type_traits<ParticleRecord, true>
    {
        using type = ParticleRecordAD;
    };

    template <typename T, bool ad>
    using Type = typename type_traits<T, ad>::type;

    // Connects the emitter vertex `its` directly to the camera and adds the
    // resulting contribution to p_rec.image (primal pass).
    //
    // scene  scene providing the sensor sampling routine and the camera.
    // its    sampled emitter position; its.ptr_emitter is dereferenced.
    // p_rec  particle state; supplies the sampler, the current throughput and
    //        the image buffer that receives the contribution.
    //
    // With USE_NEW_CAMERA_SAMPLING the camera itself splats the value via
    // accumulateDirect(); the legacy branch writes up to four pixels by hand.
    void handleEmission(const Scene &scene, const Intersection &its, const ParticleRecord &p_rec)
    {
#ifdef USE_NEW_CAMERA_SAMPLING
        CameraDirectSamplingRecord cRec;
        Spectrum contrib(scene.sampleAttenuatedSensorDirect(its, p_rec.sampler, cRec));
        // Nothing to do if the attenuated sensor response is negligible.
        if (contrib.isZero(Epsilon) || !(cRec.baseVal > Epsilon))
            return;

        // Directional emission towards the camera, scaled by the path throughput.
        contrib *= its.ptr_emitter->evalDirection(its.geoFrame.n, cRec.dir) * p_rec.throughput;
        if (contrib.isZero(Epsilon))
            return;

        scene.camera.accumulateDirect(cRec, contrib, p_rec.sampler, &p_rec.image[0]);
#else
        auto &path = p_rec.path;
        Matrix2x4 pix_uvs;
        Vector dir;
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(
            its, p_rec.sampler, pix_uvs, dir);
        if (sensor_vals.isZero())
            return;
        // One entry of sensor_vals / one column of pix_uvs per candidate pixel.
        for (int k = 0; k < 4; ++k)
        {
            if (sensor_vals[k] < Epsilon)
                continue;
            Spectrum value = p_rec.throughput * sensor_vals[k] * its.ptr_emitter->evalDirection(its.geoFrame.n, dir);
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(k));
            p_rec.image[pixel_idx] += value;

            // Disabled algorithm-1 evaluation, kept for reference:
            // path.append(scene.camera);                                                                            // NOTE
            // path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // // NOTE anti_path is the inplace modified path
            // // FIXME
            // auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic(scene, path);
            // if (valid && p_rec.enable_antithetic)
            // {
            //     // generate anti_path
            //     auto &org_its = path.vertices[path.path[path.path.size() - 2]];
            //     auto tmp_its = org_its;
            //     Spectrum value1 = algorithm1_ptracer::eval(scene, path, p_rec.sampler);
            //     org_its = anti_its;
            //     Spectrum value2 = algorithm1_ptracer::eval(scene, path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value1 * misWeight(pdf1A, pdf1B);
            //     p_rec.image[pixel_idx] += value2 * misWeight(pdf2B, pdf2A);
            //     org_its = tmp_its;
            // }
            // else
            // {
            //     Spectrum value1 = algorithm1_ptracer::eval(scene, path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value1;
            // }

            // path.path.pop_back();
        }
#endif
    }

    // Differentiable counterpart of handleEmission(): connects the emitter
    // vertex `its` to the camera and, for each of the four sensor samples whose
    // value is at least Epsilon, runs algorithm1_ptracer::d_eval on the light
    // path extended by the camera vertex.
    //
    // sceneAD  scene plus gradient storage; forwarded to d_eval.
    // its      emitter vertex (already appended to p_rec.path by traceParticle).
    // p_rec    particle state; p_rec.path receives the camera vertex for the
    //          duration of each evaluation and has it popped again afterwards.
    //
    // When FORWARD is defined, the change of the tracked shape parameter caused
    // by the d_eval calls is additionally added to p_rec.grad_image.
    void handleEmissionAD(SceneAD &sceneAD, const Intersection &its, const ParticleRecordAD &p_rec)
    {
        auto &scene = sceneAD.val;
        auto &path = p_rec.path;
        Matrix2x4 pix_uvs;
        Vector dir; // filled by the sensor sampling call; not needed afterwards
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(
            its, p_rec.sampler, pix_uvs, dir);
        if (sensor_vals.isZero())
            return;
        for (int i = 0; i < 4; i++)
        {
            if (sensor_vals[i] < Epsilon)
                continue;
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(i));
            // Temporarily close the path with the camera vertex for this pixel.
            path.append(scene.camera);                                                                            // NOTE
            path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // Snapshot used by the non-antithetic branch below.
            // NOTE(review): the surface/medium handlers take this snapshot after
            // the antithetic call; it is kept before the call here to preserve
            // existing behaviour -- confirm whether antithetic() modifies `path`.
            auto pathAD = algorithm1_ptracer::LightPathAD(path);
            // NOTE anti_path is the inplace modified p_rec.path
            // FIXME
            auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic(scene, path);

#ifdef FORWARD
            int shape_idx = scene.getShapeRequiresGrad();
            sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param = 0;
            Float param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
#endif

            if (valid && p_rec.enable_antithetic)
            {
                // The vertex before the camera is the one that gets swapped for
                // its antithetic counterpart.
                auto &org_its = path.vertices[path.path[path.path.size() - 2]];
                auto tmp_its = org_its;

                // Original path, with its pdf divided by the MIS weight.
                org_its.pdf /= misWeight(pdf1A, pdf1B);
                auto org_pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, org_pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
                // Antithetic path: same path with the vertex replaced in place.
                org_its = anti_its;
                org_its.pdf /= misWeight(pdf2B, pdf2A);
                auto anti_pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, anti_pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);

                // restore org_its
                org_its = tmp_its;
            }
            else
            {
                algorithm1_ptracer::d_eval(sceneAD, pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
            }

            // Drop the camera vertex again so the caller can keep extending the path.
            path.path.pop_back();

#ifdef FORWARD
            if (isfinite(sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param))
            {
                param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param - param;
                p_rec.grad_image[pixel_idx] += Spectrum(param, 0.f, 0.f);
            }
#endif
        }
    }

    // Connects the surface vertex `its` directly to the camera and adds the
    // resulting contribution to p_rec.image (primal pass).
    //
    // scene  scene providing the sensor sampling routine and the camera.
    // its    surface hit; its.ptr_bsdf is evaluated towards the camera.
    // p_rec  particle state; supplies the sampler, the current throughput and
    //        the image buffer that receives the contribution.
    void handleSurfaceInteraction(const Scene &scene, const Intersection &its, const ParticleRecord &p_rec)
    {
#ifdef USE_NEW_CAMERA_SAMPLING
        CameraDirectSamplingRecord cRec;
        Spectrum contrib(scene.sampleAttenuatedSensorDirect(its, p_rec.sampler, cRec));
        // Nothing to do if the attenuated sensor response is negligible.
        if (contrib.isZero(Epsilon) || !(cRec.baseVal > Epsilon))
            return;

        Vector wi_world = its.toWorld(its.wi);
        Vector wo_world = cRec.dir;
        Vector wo_local = its.toLocal(wo_world);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        const Float wiDotGeoN = wi_world.dot(its.geoFrame.n);
        const Float woDotGeoN = wo_world.dot(its.geoFrame.n);
        const bool same_side = wiDotGeoN * its.wi.z() > Epsilon && woDotGeoN * wo_local.z() > Epsilon;
        if (!same_side)
            return;

        contrib *= its.ptr_bsdf->eval(its, wo_local, EBSDFMode::EImportanceWithCorrection) * p_rec.throughput;
        if (contrib.isZero(Epsilon))
            return;

        scene.camera.accumulateDirect(cRec, contrib, p_rec.sampler, &p_rec.image[0]);
#else
        Vector dir;
        Matrix2x4 pix_uvs;
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(
            its, p_rec.sampler, pix_uvs, dir);

        if (sensor_vals.isZero())
            return;
        Vector wi = its.toWorld(its.wi);
        Vector wo = dir, wo_local = its.toLocal(wo);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        Float wiDotGeoN = wi.dot(its.geoFrame.n),
              woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 ||
            woDotGeoN * wo_local.z() <= 0)
            return;

        /* Adjoint BSDF for shading normals -- [Veach, p. 155] */
        // One entry of sensor_vals / one column of pix_uvs per candidate pixel.
        for (int k = 0; k < 4; ++k)
        {
            if (sensor_vals(k) < Epsilon)
                continue;
            Spectrum bsdf = its.ptr_bsdf->eval(its, wo_local, EBSDFMode::EImportanceWithCorrection);
            Spectrum value = sensor_vals(k) * bsdf * p_rec.throughput;
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(k));
            p_rec.image[pixel_idx] += value;

            // Disabled algorithm-1 evaluation, kept for reference:
            // p_rec.path.append(scene.camera);                                                                            // NOTE
            // p_rec.path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // // NOTE anti_path is the inplace modified p_rec.path
            // auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic_surf(scene, p_rec.path, p_rec.sampler);
            // if (valid && p_rec.enable_antithetic)
            // {
            //     // generate anti_path
            //     Intersection &org_its = p_rec.path.vertices[p_rec.path.path[p_rec.path.path.size() - 2]];
            //     auto tmp_its = org_its;
            //     Spectrum value1 = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     org_its = anti_its;
            //     Spectrum value2 = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value1 * misWeight(pdf1A, pdf1B);
            //     p_rec.image[pixel_idx] += value2 * misWeight(pdf2B, pdf2A);
            //     PSDR_WARN(std::abs(misWeight(pdf1A, pdf1B) - misWeight(pdf2B, pdf2A)) < 1e-6);
            //     org_its = tmp_its;
            // }
            // else
            // {
            //     Spectrum value1 = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value1;
            // }
            // p_rec.path.path.pop_back();
        }
#endif
    }

    // Differentiable counterpart of handleSurfaceInteraction(): connects the
    // surface vertex `its` to the camera and, for each of the four sensor
    // samples whose value is at least Epsilon, runs algorithm1_ptracer::d_eval
    // on the light path extended by the camera vertex.
    //
    // sceneAD  scene plus gradient storage; forwarded to d_eval.
    // its      surface vertex (already appended to p_rec.path by traceParticle).
    // p_rec    particle state; p_rec.path receives the camera vertex for the
    //          duration of each evaluation and has it popped again afterwards.
    //
    // When FORWARD is defined, the change of the tracked shape parameter caused
    // by the d_eval calls is additionally added to p_rec.grad_image.
    void handleSurfaceInteractionAD(SceneAD &sceneAD, const Intersection &its, const ParticleRecordAD &p_rec)
    {
        auto &scene = sceneAD.val;
        auto &path = p_rec.path;
        Vector dir;
        Matrix2x4 pix_uvs;
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(
            its, p_rec.sampler, pix_uvs, dir);

        if (sensor_vals.isZero())
            return;
        Vector wi = its.toWorld(its.wi);
        Vector wo = dir, wo_local = its.toLocal(wo);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        Float wiDotGeoN = wi.dot(its.geoFrame.n),
              woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 ||
            woDotGeoN * wo_local.z() <= 0)
            return;

        for (int i = 0; i < 4; i++)
        {
            if (sensor_vals(i) < Epsilon)
                continue;
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(i));

            // Temporarily close the path with the camera vertex for this pixel.
            path.append(scene.camera);                                                                            // NOTE
            path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // NOTE anti_path is the inplace modified path
            auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic_surf(scene, path, p_rec.sampler);
#ifdef FORWARD
            int shape_idx = scene.getShapeRequiresGrad();
            sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param = 0;
            Float param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
#endif
            if (valid && p_rec.enable_antithetic)
            {
                // The vertex before the camera is the one that gets swapped for
                // its antithetic counterpart.
                auto &org_its = path.vertices[path.path[path.path.size() - 2]];
                auto tmp_its = org_its;
                // Original path, with its pdf divided by the MIS weight.
                org_its.pdf /= misWeight(pdf1A, pdf1B);
                auto pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
                // Antithetic path: same path with the vertex replaced in place.
                org_its = anti_its;
                org_its.pdf /= misWeight(pdf2B, pdf2A);
                auto anti_pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, anti_pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);

                // restore org_its
                org_its = tmp_its;
            }
            else
            {
                auto pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
            }

            // Drop the camera vertex again so the caller can keep extending the path.
            path.path.pop_back();

#ifdef FORWARD
            if (isfinite(sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param))
            {
                param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param - param;
                p_rec.grad_image[pixel_idx] += Spectrum(param, 0.f, 0.f);
            }
#endif
        }
    }

    // Connects a scattering event inside a medium directly to the camera and
    // adds the resulting contribution to p_rec.image (primal pass).
    //
    // scene  scene providing the sensor sampling routine, camera and phase list.
    // ray    ray.org is the scattering position, ray.dir the direction the
    //        particle was travelling in.
    // p_rec  particle state; p_rec.medium must be non-null here because its
    //        phase_id is dereferenced.
    void handleMediumInteraction(const Scene &scene, const Ray &ray, const ParticleRecord &p_rec)
    {
#ifdef USE_NEW_CAMERA_SAMPLING
        CameraDirectSamplingRecord cRec;
        Spectrum contrib(scene.sampleAttenuatedSensorDirect(ray.org, p_rec.medium, p_rec.sampler, cRec));
        // Nothing to do if the attenuated sensor response is negligible.
        if (contrib.isZero(Epsilon) || !(cRec.baseVal > Epsilon))
            return;

        // Phase function between the incoming direction and the camera direction.
        const PhaseFunction *ptr_phase = scene.phase_list[p_rec.medium->phase_id];
        contrib *= ptr_phase->eval(-ray.dir, cRec.dir) * p_rec.throughput;
        if (contrib.isZero(Epsilon))
            return;

        scene.camera.accumulateDirect(cRec, contrib, p_rec.sampler, &p_rec.image[0]);
#else
        Vector2 pixel_uv;
        Vector wi = -ray.dir, wo;
        const Camera &camera = scene.camera;
        Matrix2x4 pix_uvs;
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(ray.org, p_rec.medium, p_rec.sampler, pix_uvs, wo);
        if (sensor_vals.isZero())
            return;

        const PhaseFunction *ptr_phase = scene.phase_list[p_rec.medium->phase_id];
        // One entry of sensor_vals / one column of pix_uvs per candidate pixel.
        for (int k = 0; k < 4; ++k)
        {
            if (sensor_vals(k) < Epsilon)
                continue;
            Spectrum value = sensor_vals(k) * ptr_phase->eval(wi, wo) * p_rec.throughput;
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(k));
            p_rec.image[pixel_idx] += value;
            // Disabled algorithm-1 evaluation, kept for reference:
            // NOTE algorithm 1
            // p_rec.path.append(scene.camera);                                                                            // NOTE
            // p_rec.path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic_vol(scene, p_rec.path, p_rec.sampler, p_rec.is_equal_trans);
            // // Statistics::getInstance().getCounter("ptracer2", "path number") += 1;
            // if (valid && p_rec.enable_antithetic)
            // {
            //     // Statistics::getInstance().getCounter("ptracer2", "antithetic path number") += 1;
            //     // generate anti_path
            //     Intersection &org_its = p_rec.path.vertices[p_rec.path.path[p_rec.path.path.size() - 2]];
            //     auto tmp_its = org_its;
            //     Spectrum value1 = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     org_its = anti_its;
            //     Spectrum value2 = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value1 * misWeight(pdf1A, pdf1B);
            //     p_rec.image[pixel_idx] += value2 * misWeight(pdf2B, pdf2A);
            //     PSDR_WARN(std::abs(misWeight(pdf1A, pdf1B) - misWeight(pdf2B, pdf2A)) < 1e-6);
            //     org_its = tmp_its;
            // }
            // else
            // {
            //     // Statistics::getInstance().getCounter("ptracer2", "non-antithetic path number") += 1;
            //     Spectrum value = algorithm1_ptracer::eval(scene, p_rec.path, p_rec.sampler);
            //     p_rec.image[pixel_idx] += value;
            // }
            // p_rec.path.path.pop_back();
        }
#endif
    }

    // Differentiable counterpart of handleMediumInteraction(): connects a
    // scattering event inside a medium to the camera and, for each of the four
    // sensor samples whose value is at least Epsilon, runs
    // algorithm1_ptracer::d_eval on the light path extended by the camera vertex.
    //
    // sceneAD  scene plus gradient storage; forwarded to d_eval.
    // ray      ray.org is the scattering position (the matching medium vertex
    //          has already been appended to p_rec.path by traceParticle).
    // p_rec    particle state; p_rec.path receives the camera vertex for the
    //          duration of each evaluation and has it popped again afterwards.
    //
    // When FORWARD is defined, the change of the tracked shape parameter caused
    // by the d_eval calls is additionally added to p_rec.grad_image.
    void handleMediumInteractionAD(SceneAD &sceneAD, const Ray &ray, const ParticleRecordAD &p_rec)
    {
        const Scene &scene = sceneAD.val;
        Vector wo; // filled by the sensor sampling call; not needed afterwards
        Matrix2x4 pix_uvs;
        Array4 sensor_vals = scene.sampleAttenuatedSensorDirect(ray.org, p_rec.medium, p_rec.sampler, pix_uvs, wo);
        if (sensor_vals.isZero())
            return;

        auto &path = p_rec.path;
        for (int i = 0; i < 4; i++)
        {
            if (sensor_vals(i) < Epsilon)
                continue;
            int pixel_idx = scene.camera.getPixelIndex(pix_uvs.col(i));

#ifdef FORWARD
            int shape_idx = scene.getShapeRequiresGrad();
            sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param = 0;
            Float param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param;
#endif
            // Temporarily close the path with the camera vertex for this pixel.
            path.append(scene.camera);                                                                            // NOTE
            path.vertices.back().pixel_idx = unravel_index(pixel_idx, {scene.camera.width, scene.camera.height}); // NOTE
            // NOTE anti_path is the inplace modified path
            auto [valid, anti_its, pdf1A, pdf1B, pdf2A, pdf2B] = algorithm1_ptracer::antithetic_vol(scene, path, p_rec.sampler, p_rec.is_equal_trans);
            // Statistics::getInstance().getCounter("ptracer2", "path number") += 1;

            if (valid && p_rec.enable_antithetic)
            {
                // Statistics::getInstance().getCounter("ptracer2", "antithetic path number") += 1;
                // The vertex before the camera is the one that gets swapped for
                // its antithetic counterpart.
                auto &org_its = path.vertices[path.path[path.path.size() - 2]];
                auto tmp_its = org_its;
                // Original path, with its pdf divided by the MIS weight.
                org_its.pdf /= misWeight(pdf1A, pdf1B);
                auto pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
                // Antithetic path: same path with the vertex replaced in place.
                org_its = anti_its;
                org_its.pdf /= misWeight(pdf2B, pdf2A);
                auto anti_pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, anti_pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);

                // restore org_its
                org_its = tmp_its;
            }
            else
            {
                // Statistics::getInstance().getCounter("ptracer2", "non-antithetic path number") += 1;
                auto pathAD = algorithm1_ptracer::LightPathAD(path);
                algorithm1_ptracer::d_eval(sceneAD, pathAD, p_rec.d_image[pixel_idx], p_rec.sampler);
            }

            // Drop the camera vertex again so the caller can keep extending the path.
            path.path.pop_back();

#ifdef FORWARD
            if (isfinite(sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param))
            {
                param = sceneAD.gm.get(omp_get_thread_num()).shape_list[shape_idx]->param - param;
                p_rec.grad_image[pixel_idx] += Spectrum(param, 0.f, 0.f);
            }
#endif
        }
    }

    // Traces one particle from an emitter through the scene and, at every
    // emitter / surface / medium vertex, calls the matching handle* function to
    // connect that vertex to the camera.
    //
    // ad      false: primal pass (handle*), true: differentiable pass (handle*AD).
    // sceneV  const Scene (ad == false) or SceneAD (ad == true).
    // p_rec   ParticleRecord / ParticleRecordAD; its throughput, medium, bounce
    //         counter and light path are updated in place.
    //
    // The loop ends when max_bounces is reached, the throughput becomes zero,
    // the ray leaves the scene, or direction sampling fails.
    template <bool ad>
    void traceParticle(Type<Scene, ad> &sceneV, Type<ParticleRecord, ad> &p_rec)
    {
        const Scene &scene = value(sceneV);
        algorithm1_ptracer::LightPath &path = p_rec.getPath();

        // Start on an emitter and connect that vertex to the camera.
        Intersection its;
        p_rec.throughput = scene.sampleEmitterPosition(p_rec.sampler->next2D(), its);
        path.append(its); // NOTE
        if constexpr (ad)
            handleEmissionAD(sceneV, its, p_rec);
        else
            handleEmission(scene, its, p_rec);
        Ray ray;
        ray.org = its.p;
        Float pdf;

        // Position of the last vertex that was appended to the path; used for
        // the geometric term of the next vertex' pdf.
        Vector preX = its.p;
        MediumSamplingRecord mRec;
        p_rec.throughput *= its.ptr_emitter->sampleDirection(p_rec.sampler->next2D(), ray.dir, &pdf);
        Float pdfFailure = pdf; // keep track of the pdf of hitting a surface
        Float pdfSuccess = pdf; // keep track of the pdf of hitting a medium
        // The sampled direction is local to the emitter surface, if there is one.
        if (its.ptr_shape != nullptr)
            ray.dir = its.geoFrame.toWorld(ray.dir);
        bool on_surface = true;
        p_rec.med_id = its.getTargetMediumId(ray.dir);
        p_rec.medium = its.getTargetMedium(ray.dir);
        while (p_rec.bounces < p_rec.max_bounces)
        {
            if (p_rec.throughput.isZero())
                break;
            scene.rayIntersect(ray, on_surface, its);
            // mRec is only filled in when the particle is inside a medium.
            bool inside_med = p_rec.medium != nullptr &&
                              p_rec.medium->sampleDistance(ray, its.t, p_rec.sampler, mRec);
            if (inside_med)
            {
                // sampled a medium interaction
                // Move the ray origin to the scattering position. This is done
                // only here: outside this branch mRec.p may be unset or stale,
                // and the surface branch replaces `ray` before reading it again.
                ray.org = mRec.p;
                path.append({mRec, p_rec.med_id, pdfSuccess * mRec.pdfSuccess * geometric(preX, mRec.p)}); // NOTE
                p_rec.throughput *= mRec.sigmaS * mRec.transmittance / mRec.pdfSuccess;
                if constexpr (!ad)
                    handleMediumInteraction(scene, ray, p_rec);
                else
                    handleMediumInteractionAD(sceneV, ray, p_rec);
                // Sample the outgoing direction from the phase function.
                const auto *phase = scene.phase_list[p_rec.medium->phase_id];
                Vector wo;
                Float phase_val = phase->sample(-ray.dir, p_rec.sampler->next2D(), wo);
                if (phase_val < Epsilon)
                    break;
                Float phasePdf = phase->pdf(-ray.dir, wo);
                // pdf init, solid angle measure
                pdf = phasePdf;
                pdfFailure = phasePdf;
                pdfSuccess = phasePdf;
                p_rec.throughput *= phase_val;
                ray.dir = wo;
                on_surface = false;
                preX = ray.org;
            }
            else
            {
                // hit the surface
                // if the ray going through a medium
                if (p_rec.medium)
                {
                    pdfFailure *= mRec.pdfFailure;
                    pdfSuccess *= mRec.pdfFailure;
                    p_rec.throughput *= mRec.transmittance / mRec.pdfFailure;
                }
                if (p_rec.throughput.isZero())
                    break;
                if (!its.isValid())
                    break;
                // Surfaces with a null BSDF are passed through without creating
                // a path vertex or a camera connection.
                if (!its.ptr_bsdf->isNull())
                {
                    its.pdf = pdfFailure * geometric(preX, its.p, its.geoFrame.n);
                    its.medium_id = p_rec.med_id;
                    path.append(its); // NOTE
                    if constexpr (ad)
                        handleSurfaceInteractionAD(sceneV, its, p_rec);
                    else
                        handleSurfaceInteraction(scene, its, p_rec);
                }
                Float bsdf_pdf, bsdf_eta;
                Vector wo_local, wo;
                EBSDFMode mode = EBSDFMode::EImportanceWithCorrection;
                auto bsdf_weight = its.sampleBSDF(p_rec.sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, mode);
                // pdf init, solid angle measure
                if (!its.ptr_bsdf->isNull())
                {
                    pdf = bsdf_pdf;
                    pdfFailure = bsdf_pdf;
                    pdfSuccess = bsdf_pdf;
                }
                if (bsdf_weight.isZero())
                    break;
                wo = its.toWorld(wo_local);
                Vector wi = -ray.dir;
                // Stop if shading and geometric normals disagree about the side
                // of the surface for either direction.
                Float wiDotGeoN = wi.dot(its.geoFrame.n),
                      woDotGeoN = wo.dot(its.geoFrame.n);
                if (wiDotGeoN * its.wi.z() <= 0 ||
                    woDotGeoN * wo_local.z() <= 0)
                {
                    break;
                }
                p_rec.throughput *= bsdf_weight;
                if (its.isMediumTransition())
                {
                    p_rec.med_id = its.getTargetMediumId(woDotGeoN);
                    p_rec.medium = its.getTargetMedium(woDotGeoN);
                }
                ray = Ray(its.p, wo);
                on_surface = true;
                if (!its.ptr_bsdf->isNull())
                    preX = ray.org;
            }
            p_rec.bounces++;
        }
    }
}

// Primal render: traces particles from the emitters and returns the image as a
// tensor.
//
// scene    scene to render (read-only).
// options  uses seed, block_size, num_samples and max_bounces.
//
// Work is split into blocks of block_size^2 sampler streams; every stream
// traces options.num_samples particles. Each thread accumulates into its own
// buffer obtained from ThreadManager; the buffers are merged and divided by the
// total number of traced particles at the end.
ArrayXd ParticleTracer2::renderC(const Scene &scene, const RenderOptions &options) const
{
    const int nworker = omp_get_num_procs();

    int num_pixels = scene.camera.getNumPixels();
    const int block_area = options.block_size * options.block_size;
    // Number of blocks needed to cover all pixels (rounded up).
    const int num_blocks = (num_pixels - 1) / block_area + 1;

    std::vector<Spectrum> spec_list(num_pixels, Spectrum::Zero());
    ThreadManager thread_manager(spec_list, nworker);
    int blockProcessed = 0;
    Timer _("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int block = 0; block < num_blocks; ++block)
    {
        // A block is processed by a single thread from start to finish.
        const int thread_id = omp_get_thread_num();
        const int first_stream = block * block_area;
        for (int offset = 0; offset < block_area; ++offset)
        {
            // One sampler stream per (block, offset) pair.
            RndSampler sampler(options.seed, first_stream + offset);
            for (int s = 0; s < options.num_samples; ++s)
            {
                ParticleRecord p_rec = {&sampler, options.max_bounces, 0, -1, nullptr,
                                        Spectrum::Zero(), thread_manager.get(thread_id),
                                        enable_antithetic, is_equal_trans};
                traceParticle<false>(scene, p_rec);
            }
        }

        if (verbose)
        {
            // The shared counter is only touched inside the critical section.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / num_blocks);
        }
    }
    if (verbose)
        std::cout << std::endl;

    // Statistics::getInstance().printStats();

    thread_manager.merge();
    // Normalise by the number of particles that were actually traced.
    const long long num_samples = static_cast<long long>(num_blocks) * block_area * options.num_samples;
    const Float normalisation = static_cast<Float>(num_samples);
    for (auto &spec : spec_list)
        spec /= normalisation;
    return from_spectrum_list_to_tensor(spec_list, num_pixels);
}

// Differentiable render: traces particles with the *AD handlers and returns the
// per-pixel gradient image as a tensor.
//
// sceneAD  scene plus gradient storage; sceneAD.gm is merged before returning.
// options  uses seed, num_samples and max_bounces.
// d_image  input tensor; it is divided by (number of pixels * num_samples)
//          before being handed to the handlers, so no further normalisation
//          is applied to the result here.
//
// Returns grad_image, which the handlers only write to when FORWARD is defined.
ArrayXd ParticleTracer2::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &d_image) const
{
    std::cout << (enable_antithetic ? "PTracer renderD with antithetic sampling"
                                    : "PTracer renderD without antithetic sampling")
              << std::endl;

    const auto &camera = sceneAD.val.camera;
    const int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    BlockedImage blocks({camera.width, camera.height}, {4, 4});
    // image
    std::vector<Spectrum> spec_list(num_pixels, Spectrum::Zero());
    ThreadManager thread_manager(spec_list, nworker);
    // d_image, pre-scaled by 1 / (num_pixels * num_samples)
    auto d_image_spec_list = from_tensor_to_spectrum_list(
        d_image / num_pixels / options.num_samples, num_pixels);
    // gradient image
    std::vector<Spectrum> grad_image(num_pixels, Spectrum::Zero());
    ThreadManager grad_images(grad_image, nworker);
    int blockProcessed = 0;
    Timer _("Render interior");

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < blocks.m_BlocksTotal; i++)
    {
        // A block is processed by a single thread from start to finish.
        const int thread_id = omp_get_thread_num();
        ImageBlock block = blocks.getBlock(i);
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            // One sampler stream per pixel.
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, {camera.width, camera.height});
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.num_samples; j++)
            {
                ParticleRecordAD p_rec = {
                    &sampler,
                    options.max_bounces,
                    0,
                    -1,
                    nullptr,
                    Spectrum::Zero(),
                    thread_manager.get(thread_id),
                    d_image_spec_list,
                    grad_images.get(thread_id),
                    enable_antithetic,
                    is_equal_trans};
                traceParticle<true>(sceneAD, p_rec);
            }
        }

        if (verbose)
        {
            // The shared counter is only touched inside the critical section.
            // NOTE(review): the loop bound is blocks.m_BlocksTotal while the
            // progress denominator is blocks.size() -- confirm they agree.
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / blocks.size());
        }
    }
    if (verbose)
        std::cout << std::endl;

    // Statistics::getInstance().printStats();

    // Combine the per-thread buffers and gradients.
    thread_manager.merge();
    grad_images.merge();
    sceneAD.gm.merge();
    return from_spectrum_list_to_tensor(grad_image, num_pixels);
}