#include "boundary.h"
#include <render/scene.h>
#include <core/math_func.h>
#include <core/timer.h>
#include <core/logger.h>
namespace
{
    // File-local debug buffer. renderD() re-creates it on every call and returns
    // its contents; in FORWARD builds handleBoundaryInteraction() accumulates
    // per-thread, per-pixel values into it.
    static DebugInfo debugInfo;
    [[maybe_unused]] void velocity(const Scene &scene,
                                   const BoundarySamplingRecord &bRec,
                                   Float &res)
    {
        const Shape *shape = scene.shape_list[bRec.shape_id];
        const Edge &edge = shape->edges[bRec.edge_id];
        const Vector &xB_0 = shape->getVertex(edge.v0);
        const Vector &xB_1 = shape->getVertex(edge.v1);
        const Vector &xB_2 = shape->getVertex(edge.v2);

        const Shape *shapeS = scene.shape_list[bRec.shape_id_S];
        const auto &indS = shapeS->getIndices(bRec.tri_id_S);
        const Vector &xS_0 = shapeS->getVertex(indS[0]);
        const Vector &xS_1 = shapeS->getVertex(indS[1]);
        const Vector &xS_2 = shapeS->getVertex(indS[2]);

        res = normal_velocity(scene.camera.cpos,
                              xB_0, xB_1, xB_2, bRec.t, bRec.dir,
                              xS_0, xS_1, xS_2);
    }

    /**
     * Differentiates velocity() with Enzyme, seeding its output with d_u.
     *
     * @param scene    Primal scene.
     * @param d_scene  Shadow of `scene`, passed to Enzyme alongside it.
     * @param eRec     Boundary sample, passed to Enzyme as a constant.
     * @param d_u      Derivative seed paired with velocity()'s output argument.
     *
     * When ENZYME and ENZYME_BOUNDARY_PRIMARY are not both defined, the call is
     * compiled out and this function does nothing.
     */
    void d_velocity(const Scene &scene, Scene &d_scene,
                    const EdgeRaySamplingRecord &eRec,
                    Float d_u)
    {
        // Placeholder for velocity()'s primal output; its value is never read here.
        [[maybe_unused]] Float u;
#if defined(ENZYME) && defined(ENZYME_BOUNDARY_PRIMARY)
        // Argument annotations follow Enzyme's calling convention:
        //   enzyme_dup       -> (primal, shadow) pair,
        //   enzyme_const     -> argument without a shadow,
        //   enzyme_dupnoneed -> (primal, shadow) pair whose primal result is not needed.
        __enzyme_autodiff((void *)velocity,
                          enzyme_dup, &scene, &d_scene,
                          enzyme_const, &eRec,
                          enzyme_dupnoneed, &u, &d_u);
#endif
    }

    /**
     * A sub-path: a list of surface interactions plus one Float per vertex.
     *
     * `vertices` and `weight` are kept the same length: append() pushes to both
     * and clear() empties both. Weights start at 0 and are filled in by
     * sampleLightPath().
     */
    struct LightPath
    {
        LightPath() = default;

        /// Removes all vertices and their weights.
        void clear()
        {
            vertices.clear();
            weight.clear();
        }

        /// Calls setZero() on every stored vertex; the weights are left untouched.
        void setZero()
        {
            for (auto &v : vertices)
                v.setZero();
        }

        /// Appends a vertex together with a zero weight.
        void append(const Intersection &v)
        {
            vertices.push_back(v);
            weight.push_back(0.f);
        }

        /// Reserves storage for `n` vertices and `n` weights.
        void reserve(int n)
        {
            vertices.reserve(n);
            weight.reserve(n);
        }

        /// Number of vertices, as a signed int so it can be mixed with the
        /// signed s/t indices used by the callers.
        int size() const { return static_cast<int>(vertices.size()); }

        const Intersection &operator[](std::size_t idx) const { return vertices[idx]; }
        Intersection &operator[](std::size_t idx) { return vertices[idx]; }

        // Selects the BSDF mode and the first-bounce behaviour in sampleLightPath().
        // Initialised so that reading it before assignment is well defined.
        bool isCameraPath = false;
        std::vector<Intersection> vertices;
        std::vector<Float> weight;
    };
    
    /**
     * State shared by one bidirectional sample: the random sampler, the bounce
     * budget, and the two sub-paths.
     *
     * The sub-paths and the sampler are `mutable`, so the const accessors below
     * hand out non-const references.
     */
    struct BDPTRecord
    {
        BDPTRecord(RndSampler* rnd_sampler, int maxBounces)
            : sampler(rnd_sampler),
              max_bounces(maxBounces)
        {
            // Both sub-paths start out empty.
            camera_path.clear();
            light_path.clear();
        }

        mutable RndSampler *sampler;
        int max_bounces;

        mutable LightPath light_path;
        mutable LightPath camera_path;

        LightPath &getCameraPath() const { return camera_path; }
        LightPath &getLightPath() const { return light_path; }

        /// MIS weight for the strategy that joins camera vertex `s` with light
        /// vertex `t`; defined out of line.
        Float computeWeight(const Scene& scene, int s, int t, const Intersection &cam);
    };

    /// Returns (pdf0 / pdf1)^2. No guard against pdf1 == 0 is applied here.
    Float misRatio(Float pdf0, Float pdf1)
    {
        const Float ratio = pdf0 / pdf1;
        return ratio * ratio;
    }

    void evalBDPTPath(const Scene& scene, const LightPath& camera_path, int s, const LightPath& light_path, int t,
                    Spectrum& contrb_all)
    {
        if ( t == -1 ) {
            //assert(s > 0);
            const Intersection& v_cam = camera_path[s];
            const Vector& pre_p = (s == 0) ? scene.camera.cpos : camera_path[s-1].p;
            if (v_cam.isEmitter()) {
                Vector dir = (pre_p - v_cam.p).normalized();
                contrb_all = v_cam.Le(dir) * camera_path[s].value;
            }
        } else {
            const Intersection& v_lgt = light_path[t];
            const Intersection& v_cam = camera_path[s];
            if (!scene.isVisible(v_lgt.p, true, v_cam.p, true)) return;

            Vector dir = v_lgt.p - v_cam.p;
            Float dist2 = dir.squaredNorm();
            dir /= std::sqrt(dist2);
            Spectrum seg_lgt = (t == 0) ? Spectrum::Ones() * v_lgt.ptr_emitter->evalDirection(v_lgt.geoFrame.n, -dir)
                                        : v_lgt.evalBSDF(v_lgt.toLocal(-dir), EBSDFMode::EImportanceWithCorrection);
            seg_lgt /= dist2;
            Spectrum seg_cam = v_cam.evalBSDF(v_cam.toLocal(dir), EBSDFMode::ERadiance);
            contrb_all = seg_cam * seg_lgt * v_cam.value * v_lgt.value;
        }
    }



    /**
     * Extends `path` by a random walk and then fills in `path.weight`.
     *
     * Preconditions (established by bidirTrace()):
     *   - `path` already holds its first vertex and `path.isCameraPath` is set;
     *   - `path.reserve()` was called, because the bounce budget is taken from
     *     the reserved capacity.
     *
     * For every appended vertex, `pdf` is the sampling pdf multiplied by the
     * geometric term and `value` is the accumulated throughput.
     *
     * Despite the name, this is used for both the camera and the light sub-path.
     */
    void sampleLightPath(const Scene& scene, RndSampler* sampler, LightPath& path)
    {
        // NOTE: the budget is derived from capacity(), which the standard only
        // guarantees to be >= the reserved amount; an implementation that
        // over-allocates would allow extra bounces.
        const int max_bounces = static_cast<int>(path.vertices.capacity()) - 1;
        Float pdf = 0;
        Ray ray;
        Intersection its = path[0];
        for (int ibounce = 0; ibounce < max_bounces; ibounce++) {
            Spectrum throughput(1.0f);
            if (ibounce == 0 && !path.isCameraPath) {
                // First segment of a light sub-path: sample the emission direction.
                throughput *= its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
                if (its.ptr_shape != nullptr)
                    ray.dir = its.geoFrame.toWorld(ray.dir);
            } else {
                Float bsdf_eta;
                Vector wo_local, wo;
                EBSDFMode mode = path.isCameraPath ? EBSDFMode::ERadiance : EBSDFMode::EImportanceWithCorrection;
                throughput *= its.sampleBSDF(sampler->next3D(), wo_local, pdf, bsdf_eta, mode);
                wo = its.toWorld(wo_local);
                // check light leak
                Vector wi = its.toWorld(its.wi);
                Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0)
                    throughput.setZero();
                ray.dir = wo;
            }
            ray.org = its.p;
            if (throughput.isZero(Epsilon) || !scene.rayIntersect(ray, true, its)) break;
            Float G = geometric(ray.org, its.p, its.geoFrame.n);
            if (G < Epsilon || its.t < ShadowEpsilon) break;     // Hack for fixing numerical issues
            its.pdf = pdf * G;
            its.value = path[ibounce].value * throughput;
            path.append(its);
        }

        // Per-vertex running sums consumed by BDPTRecord::computeWeight():
        //   weight[i] = 1 + weight[i-1] * misRatio(pdf1, path[i-2].pdf)   (second term only for i > 1)
        // where pdf1 is the reverse pdf computed during the previous iteration.
        Float pdf1 = 1.0;
        for (int ibounce = 1; ibounce < path.size(); ibounce++) {
            path.weight[ibounce] = 1.0f;
            if (ibounce > 1) {
                path.weight[ibounce] += path.weight[ibounce-1] * misRatio(pdf1, path[ibounce-2].pdf);
            }
            if (ibounce < path.size() - 1) {
                const Intersection &its0 = path[ibounce-1];
                Intersection &its1 = path[ibounce];
                // Temporarily swap the incident direction to evaluate the pdf of
                // sampling the walk in reverse (towards the previous vertex),
                // then restore it.
                Vector wo_local = its1.wi;
                Vector wi = (path[ibounce+1].p - its1.p).normalized();
                its1.wi = its1.toLocal(wi);
                pdf1 = its1.pdfBSDF(wo_local) * geometric(its1.p, its0.p, its0.geoFrame.n);
                its1.wi = wo_local;
            }
        }
    }

    /**
     * MIS weight of the strategy that connects camera vertex `s` to light vertex `t`.
     *
     * @param scene  Used for the emitter-sampling pdf when `t < 0`.
     * @param s      Index of the last camera sub-path vertex; a negative value
     *               means the light vertex is connected directly to `cam`.
     * @param t      Index of the last light sub-path vertex; a negative value
     *               means the camera sub-path ended on an emitter.
     * @param cam    Surface point the camera sub-path was started from.
     * @return 1 / (1 + sum of squared pdf ratios of the alternative strategies).
     *
     * The function temporarily overwrites `wi` on the end vertices to evaluate
     * reverse pdfs and restores it before returning.
     */
    Float BDPTRecord::computeWeight(const Scene& scene, int s, int t, const Intersection &cam)
    {
        Float pdf_fwd1, pdf_fwd2, pdf_bwd1, pdf_bwd2;
        Float inv_weight = 1.0f;

        // Strategies that shorten the camera sub-path.
        if (s >= 0) {
            pdf_fwd1 = camera_path[s].pdf;
            Vector dir;
            if (t >= 0) {
                dir = camera_path[s].p - light_path[t].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = (t == 0) ? light_path[t].pdfEmitter(light_path[t].toLocal(dir))      // Here maybe we should use geometric normal
                                : light_path[t].pdfBSDF(light_path[t].toLocal(dir));
                // std::abs (not the unqualified C abs, which may resolve to the
                // int overload and truncate the cosine).
                pdf_bwd1 *= std::abs(dir.dot(camera_path[s].geoFrame.n)) / dist2;
            } else {
                pdf_bwd1 = scene.pdfEmitterSample(camera_path[s]);
            }

            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s-1, t+1) / pdf(s, t)
            if (s >= 1) {
                pdf_fwd2 = camera_path[s-1].pdf;
                Vector wo_local = camera_path[s].wi;
                if (t >= 0) {
                    camera_path[s].wi = camera_path[s].toLocal(-dir);
                    pdf_bwd2 = camera_path[s].pdfBSDF(wo_local);
                    camera_path[s].wi = wo_local;
                } else {
                    pdf_bwd2 = camera_path[s].pdfEmitter(wo_local);
                }
                pdf_bwd2 *= geometric(camera_path[s].p, camera_path[s-1].p, camera_path[s-1].geoFrame.n);
                // Remaining shorter camera sub-paths, via the precomputed running sum.
                inv_weight += camera_path.weight[s] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }
        }

        // Strategies that shorten the light sub-path.
        if (t >= 0) {
            pdf_fwd1 = light_path[t].pdf;
            if (s >= 0) {
                Vector dir = light_path[t].p - camera_path[s].p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = camera_path[s].pdfBSDF(camera_path[s].toLocal(dir));
                pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            }
            else {
                Vector dir = light_path[t].p - cam.p;
                Float dist2 = dir.squaredNorm();
                dir /= std::sqrt(dist2);
                pdf_bwd1 = cam.pdfBSDF(cam.toLocal(dir));
                pdf_bwd1 *= std::abs(light_path[t].geoFrame.n.dot(dir)) / dist2;
            }
            inv_weight += misRatio(pdf_bwd1, pdf_fwd1);     // pdf(s+1, t-1) / pdf(s, t)

            if (t >= 1) {
                pdf_fwd2 = light_path[t-1].pdf;
                Vector wo_local = light_path[t].wi;
                Vector wi = (((s >= 0) ? camera_path[s].p : cam.p) - light_path[t].p).normalized();
                light_path[t].wi = light_path[t].toLocal(wi);
                pdf_bwd2 = light_path[t].pdfBSDF(wo_local);
                light_path[t].wi = wo_local;
                pdf_bwd2 *= geometric(light_path[t].p, light_path[t-1].p, light_path[t-1].geoFrame.n);
                // Sum(pdf(s+t-j, j)) / pdf(s, t), where -1 <= j <= t-2
                inv_weight += light_path.weight[t] * misRatio(pdf_bwd1 * pdf_bwd2, pdf_fwd1 * pdf_fwd2);
            }
        }
        return 1.f / inv_weight;
    }

    /**
     * Bidirectional estimate associated with the first surface point hit by `_ray`.
     *
     * The point hit by `_ray` ("cam") plays the role of the sensor end: a
     * "camera" sub-path is grown from it by BSDF sampling, a light sub-path is
     * grown from a sampled emitter position, and the connections between the two
     * are summed with the weights from BDPTRecord::computeWeight().
     *
     * @param scene      Scene to trace in.
     * @param sampler    Random sampler; consumed by both sub-paths.
     * @param _ray       Ray whose first hit defines `cam`.
     * @param max_depth  Bounce budget stored in the BDPTRecord.
     * @return Sum of all finite weighted contributions; zero if `_ray` hits nothing.
     */
    Spectrum bidirTrace(const Scene &scene, RndSampler *sampler, const Ray &_ray, int max_depth)
    {
        BDPTRecord b_rec = BDPTRecord(sampler, max_depth);
        Spectrum ret = Spectrum::Zero();

        // Without a valid first hit there is nothing to connect to; bail out
        // instead of sampling a BSDF on a default-constructed Intersection.
        Intersection cam;
        Ray prim_ray = _ray;
        if (!scene.rayIntersect(prim_ray, true, cam))
            return ret;

        LightPath &camera_path = b_rec.getCameraPath();
        // sample camera path
        {
            camera_path.isCameraPath = true;
            camera_path.reserve(b_rec.max_bounces + 1);
            Ray ray;
            Float bsdf_eta;
            Vector wo_local, wo;
            EBSDFMode mode = EBSDFMode::ERadiance;
            Float pdf;
            Spectrum val = cam.sampleBSDF(sampler->next3D(), wo_local, pdf, bsdf_eta, mode);
            wo = cam.toWorld(wo_local);
            ray.org = cam.p;
            ray.dir = wo;
            Intersection its;
            if (scene.rayIntersect(ray, true, its)) {
                // NOTE(review): unlike the vertices appended by sampleLightPath(),
                // this pdf is not multiplied by a geometric term -- confirm intended.
                its.pdf = pdf;
                its.value = val;
                camera_path.append(its);
                sampleLightPath(scene, b_rec.sampler, camera_path);
            }
        }

        LightPath &light_path = b_rec.getLightPath();
        // sample light path
        {
            light_path.isCameraPath = false;
            light_path.reserve(b_rec.max_bounces + 1);
            Intersection its;
            its.value = scene.sampleEmitterPosition(b_rec.sampler->next2D(), its);
            light_path.append(its);
            sampleLightPath(scene, b_rec.sampler, light_path);
        }

        // zero-bounce contrb.
        if (camera_path.size() > 0) {
            Spectrum path_val(0.f);
            evalBDPTPath(scene, camera_path, 0, light_path, -1, path_val);
            if (!path_val.isZero()) {
                Float mis_weight = b_rec.computeWeight(scene, 0, -1, cam);
                if((mis_weight * path_val).allFinite())
                    ret += mis_weight * path_val;
            }
        }

        // contrb from jittered sampling (i.e. s >= 0)
        for(int nbounce = 1; nbounce <= b_rec.max_bounces; nbounce++) {
            // Keep both s and t = nbounce - 1 - s inside their sub-paths
            // (t == -1 is allowed and means "camera sub-path hit an emitter").
            int s_min = std::max(0, nbounce - light_path.size());
            int s_max = std::min(nbounce, camera_path.size() - 1);
            for (int s = s_min; s <= s_max; s++) {
                int t = nbounce - 1 - s;
                Spectrum path_val(0.f);
                evalBDPTPath(scene, camera_path, s, light_path, t, path_val);
                if (path_val.isZero()) continue;
                Float mis_weight = b_rec.computeWeight(scene, s, t, cam);
                if((mis_weight * path_val).allFinite())
                    ret += mis_weight * path_val;
            }
        }

        // Connect light sub-path vertices directly to `cam` (s == -1).
        // NOTE(review): t starts at 1, so the (s = -1, t = 0) strategy is never
        // evaluated although computeWeight(0, -1) accounts for it -- confirm intended.
        for (int t = 1; t < light_path.size(); t++) {
            Intersection& its = light_path[t];
            if (!scene.isVisible(its.p, true, cam.p, true)) continue;
            Vector wo = (cam.p - its.p);
            Float inv_dist = 1.0f / wo.norm();
            wo *= inv_dist;
            Vector wi = its.toWorld(its.wi);
            Vector wo_local = its.toLocal(wo);
            // NOTE(review): `wo` points from the light vertex towards `cam`, so this
            // is the direction *away* from the light vertex in cam's frame, whereas
            // evalBDPTPath() and computeWeight() use (light.p - cam.p) on the camera
            // side. Looks like a sign inconsistency -- please verify.
            Vector wo_local_cam = cam.toLocal(wo);
            /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
            Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
            if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) continue;
            Spectrum path_val = cam.evalBSDF(wo_local_cam) * inv_dist * inv_dist *
                                its.value * its.evalBSDF(wo_local, EBSDFMode::EImportanceWithCorrection);
            if (path_val.isZero()) continue;
            Float mis_weight = b_rec.computeWeight(scene, -1, t, cam);
            if((mis_weight * path_val).allFinite())
                ret += mis_weight * path_val;
        }
        return ret;
    }

    /**
     * Projects the surface point `p` onto the sensor and back-propagates the
     * resulting scalar seed through velocity().
     *
     * @param p        Surface point to connect to the camera.
     * @param scene    Primal scene.
     * @param d_scene  Gradient accumulator handed to d_velocity().
     * @param eRec     Boundary sample forwarded to d_velocity().
     * @param sampler  Supplies the 1D sample used to pick the pixel.
     * @param weight   Per-channel weight of this sample.
     * @param max_depth  Unused in this function.
     * @param d_image  Per-pixel derivative of the loss w.r.t. the image.
     *
     * Returns without side effects on d_scene when the camera cannot be sampled
     * from `p`, when `p` is occluded, or when the sensor response is below Epsilon.
     * In FORWARD builds the change of the tracked shape parameter is additionally
     * written to the per-thread debug image.
     */
    void handleBoundaryInteraction(const Vector &p,
                                   const Scene &scene, Scene &d_scene,
                                   BoundarySamplingRecord &eRec,
                                   RndSampler &sampler, const Spectrum &weight,
                                   int max_depth, std::vector<Spectrum> &d_image)
    {
#if defined(FORWARD)
        const int gradShape = scene.getShapeRequiresGrad();
#endif
        // The point must be samplable by the camera and unoccluded.
        CameraDirectSamplingRecord cRec;
        const bool reachesSensor = scene.camera.sampleDirect(p, cRec) &&
                                   scene.isVisible(p, true, scene.camera.cpos, true);
        if (!reachesSensor)
            return;

        auto [pixel, sensorResponse] = scene.camera.sampleDirectPixel(cRec, sampler.next1D());
        if (sensorResponse < Epsilon)
            return;

        // Collapse the per-channel product into the scalar seed for d_velocity().
        Float d_u = (d_image[pixel] * weight * sensorResponse).sum();

#ifdef FORWARD
        // Reset the tracked parameter so the delta below is this sample's contribution.
        d_scene.shape_list[gradShape]->param = 0;
        const Float paramBefore = d_scene.shape_list[gradShape]->param;
#endif
        d_velocity(scene, d_scene, eRec, d_u);
#ifdef FORWARD
        const Float paramDelta = d_scene.shape_list[gradShape]->param - paramBefore;
        const int worker = omp_get_thread_num();
        debugInfo.image_per_thread[worker][pixel] += Spectrum(paramDelta, 0, 0);
#endif
    }

    void d_samplePrimaryBoundary(const Scene &scene, Scene &d_scene,
                                 RndSampler &sampler, const RenderOptions &options,
                                 const DiscreteDistribution &edge_dist,
                                 const std::vector<Vector2i> &edge_indices,
                                 std::vector<Spectrum> &d_image)
    {
        /* Sample a point on the boundary */
        BoundarySamplingRecord eRec;
        scene.sampleEdgePoint(sampler.next1D(),
                              edge_dist, edge_indices,
                              eRec);
        if (eRec.shape_id == -1)
        {
            PSDR_WARN(eRec.shape_id == -1);
            return;
        }
        const Shape *shape = scene.shape_list[eRec.shape_id];
        const Edge &edge = shape->edges[eRec.edge_id];
        Vector v0 = shape->getVertex(edge.v0);
        Vector v1 = shape->getVertex(edge.v1);
        Vector v2 = shape->getVertex(edge.v2);
        const Vector xB = v0 + (v1 - v0) * eRec.t,
                     &xD = scene.camera.cpos;
        Ray ray(xB, (xB - xD).normalized());
        Intersection its;

        if (!scene.rayIntersect(ray, true, its))
            return;
        const Vector &xS = its.p;
        // populate the data in BoundarySamplingRecord eRec
        eRec.shape_id_S = its.indices[0];
        eRec.tri_id_S = its.indices[1];
        // sanity check
        {
            // make sure the ray is tangent to the surface
            if (edge.f0 >= 0 && edge.f1 >= 0)
            {
                Vector n0 = shape->getGeoNormal(edge.f0),
                       n1 = shape->getGeoNormal(edge.f1);
                Float dotn0 = ray.dir.dot(n0),
                      dotn1 = ray.dir.dot(n1);
                if (math::signum(dotn0) * math::signum(dotn1) > -0.5)
                    return;
            }
            // NOTE prevent intersection with a backface

            Float gnDotD = its.geoFrame.n.dot(-ray.dir);
            Float snDotD = its.shFrame.n.dot(-ray.dir);
            bool success = (its.ptr_bsdf->isTransmissive() && math::signum(gnDotD) * math::signum(snDotD) > 0.5f) ||
                           (!its.ptr_bsdf->isTransmissive() && gnDotD > Epsilon && snDotD > Epsilon);
            if (!success)
                return;
            if (!scene.isVisible(xB, true, xD, true))
                return;
        }

        Vector n = (v0 - v1).cross(ray.dir).normalized();
        n *= -math::signum(n.dot(v2 - v0)); // make sure n points to the visible side

        /* Jacobian determinant that accounts for the change of variable */
        Float J = dlS_dlB(xD,
                          xB, (v0 - v1).normalized(),
                          xS, its.geoFrame.n);

        Float cosS = std::abs(its.geoFrame.n.dot(-ray.dir));
        Float baseValue = J * cosS;
        assert(baseValue > -Epsilon);
        /* Sample source path */
        Spectrum value = bidirTrace(scene, &sampler,
                            Ray(xB, ray.dir),
                            options.max_bounces);
        /* Evaluate primary boundary segment */
        handleBoundaryInteraction(xS, scene, d_scene,
                                  eRec, sampler, value * baseValue / eRec.pdf,
                                  options.max_bounces, d_image);
    }

}

/**
 * Accumulates the primary-boundary gradient into `sceneAD` and returns the
 * flattened debug image.
 *
 * @param sceneAD    Scene value, derivative and per-thread gradient manager.
 * @param options    Uses num_samples_primary_edge, block_size, seed and
 *                   (indirectly) max_bounces.
 * @param __d_image  Derivative of the loss w.r.t. the image; it is divided by
 *                   the total number of samples before use.
 * @return `flattened(debugInfo.getArray())`, on both the early-out and the
 *         normal path.
 */
ArrayXd PrimaryEdgeBidirectionalIntegrator::renderD(SceneAD &sceneAD,
                                       RenderOptions &options, const ArrayXd &__d_image) const
{
    const Scene &scene = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    const int nsamples = options.num_samples_primary_edge;
    const int nblocks = std::ceil(static_cast<Float>(camera.getNumPixels()) / (options.block_size * options.block_size));
    const int nblock_samples = options.block_size * options.block_size * nsamples;
    /* init debug info */
    debugInfo = DebugInfo(nworker, camera.getNumPixels(), nsamples);
    // Nothing to sample: return the (freshly created) debug image in the same
    // flattened layout as the normal return path below.
    if (nsamples <= 0)
        return flattened(debugInfo.getArray());

    // Pre-divide by the total sample count so each sample can be added as-is.
    std::vector<Spectrum> _d_image_spec_list = from_tensor_to_spectrum_list(
        __d_image / nblock_samples / nblocks, camera.getNumPixels());

    Timer _("Primary boundary");

    int blockProcessed = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < nblocks; i++)
    {
        for (int j = 0; j < nblock_samples; j++)
        {
            const int tid = omp_get_thread_num();
            // One sampler per sample, indexed by its global sample number.
            RndSampler sampler(options.seed, i * nblock_samples + j);
            // sample a point on the boundary
            d_samplePrimaryBoundary(scene, gm.get(tid),
                                    sampler, options,
                                    edge_dist, edge_indices,
                                    _d_image_spec_list);
        }
        // blockProcessed is only touched inside the critical section.
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(++blockProcessed) / nblocks);
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scene
    gm.merge();

    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif

    return flattened(debugInfo.getArray());
}
