#include <misc/Exception.h>
#include <psdr/core/cube_distrb.h>
#include <psdr/core/ray.h>
#include <psdr/core/intersection.h>
#include <psdr/core/sampler.h>
#include <psdr/core/transform.h>
#include <psdr/bsdf/bsdf.h>
#include <psdr/emitter/emitter.h>
#include <psdr/shape/mesh.h>
#include <psdr/scene/scene.h>
#include <psdr/sensor/perspective.h>
#include <psdr/integrator/direct_dual.h>

#include <psdr/core/AQ_distrb.h>

#include <fstream>

namespace psdr
{

/**
 * Releases the guiding distributions that preprocess_secondary_edges() allocates with `new`:
 * the per-sensor hyper-cube distributions (m_warpper), the per-sensor adaptive-quadrature
 * distributions (m_aq) and the adaptive-quadrature distribution shared by all sensors (m_global_aq).
 */
DirectIntegratorDual::~DirectIntegratorDual() {
    // Deleting a null pointer is a no-op, so unused per-sensor slots need no special handling.
    for ( auto *item : m_warpper ) delete item;
    for ( auto *item : m_aq ) delete item;
    // Allocated by option 2 of preprocess_secondary_edges(); it must be released here too,
    // otherwise it leaks.
    delete m_global_aq;
}


/**
 * Creates a dual-sample direct-illumination integrator.
 *
 * @param bsdf_samples   number of BSDF samples per shading point (>= 0)
 * @param light_samples  number of emitter samples per shading point (>= 0)
 * @param edge_direct    non-zero: secondary edges are sampled with sample_boundary_segment_direct(),
 *                       zero: with sample_edge_ray()
 *
 * At least one of the two sample counts must be positive.
 */
DirectIntegratorDual::DirectIntegratorDual(int bsdf_samples, int light_samples, int edge_direct)
    : m_bsdf_samples(bsdf_samples)
    , m_light_samples(light_samples)
    , m_edge_direct(edge_direct)
{
    PSDR_ASSERT((bsdf_samples >= 0) && (light_samples >= 0) && (bsdf_samples + light_samples > 0));
}


// Detached (non-differentiable) radiance estimate; forwards to the shared __Li implementation.
SpectrumC DirectIntegratorDual::Li(const Scene &scene, Sampler &sampler, const RayC &ray, MaskC active) const
{
    constexpr bool kDifferentiable = false;
    return __Li<kDifferentiable>(scene, sampler, ray, active);
}

// Differentiable radiance estimate; forwards to the shared __Li implementation.
SpectrumD DirectIntegratorDual::Li(const Scene &scene, Sampler &sampler, const RayD &ray, MaskD active) const
{
    constexpr bool kDifferentiable = true;
    return __Li<kDifferentiable>(scene, sampler, ray, active);
}

/**
 * Shared implementation of Li() for the detached (ad = false) and differentiable (ad = true) paths.
 *
 * The result is the emission at the first hit of `ray` (unless m_hide_emitters is set) plus a
 * direct-illumination estimate from BSDF sampling. BSDF::sampleDual() yields two directions
 * (wo1, wo2) per sample, and each of the two resulting contributions is weighted by 0.5.
 *
 * Light sampling (m_light_samples > 0) is guarded by PSDR_ASSERT(0), both in the MIS-weight code of
 * the BSDF loop and in the light-sampling loop. In the current code only pure BSDF sampling can run.
 *
 * @param scene    scene to intersect
 * @param sampler  source of the 3D BSDF samples (and of the 2D emitter samples in the disabled branch)
 * @param ray      primary rays, one per lane
 * @param active   lane mask
 */
template <bool ad>
Spectrum<ad> DirectIntegratorDual::__Li(const Scene &scene, Sampler &sampler, const Ray<ad> &ray, Mask<ad> active) const {
    Intersection<ad> its = scene.ray_intersect<ad>(ray, active);
    active &= its.is_valid();

    // Emission seen directly at the first hit.
    Spectrum<ad> result = m_hide_emitters ? zero<Spectrum<ad>>() : its.Le(active);

    BSDFArray<ad> bsdf_array = its.shape->bsdf(active);
    if ( scene.m_emitter_env != nullptr ) {
        // Skip reflectance computations for intersections on the bounding mesh
        active &= neq(bsdf_array, nullptr);
    }

    // Fast path: with a single BSDF or a single mesh, the first mesh's BSDF is called directly
    // for all lanes instead of dispatching through the per-lane BSDF array.
    const BSDF *bsdf = nullptr;
    if ( scene.m_bsdfs.size() == 1U || scene.m_meshes.size() == 1U ) {
        bsdf = scene.m_meshes[0]->m_bsdf;
    }

    if ( m_bsdf_samples > 0 ) {
        // BSDF sampling

        for ( int i = 0; i < m_bsdf_samples; ++i ) {
            // Draw a pair of directions (wo1, wo2) from one 3D sample.
            BSDFSampleDual<ad> bs;
            if ( bsdf != nullptr ) {
                bs = bsdf->sampleDual(its, sampler.next_nd<3, ad>(), active);
            } else {
                bs = bsdf_array->sampleDual(its, sampler.next_nd<3, ad>(), active);
            }
            Mask<ad> active1 = active && bs.is_valid1;
            Mask<ad> active2 = active && bs.is_valid2;

            Ray<ad> ray1(its.p, its.sh_frame.to_world(bs.wo1));
            Ray<ad> ray2(its.p, its.sh_frame.to_world(bs.wo2));


            // Only lanes whose sampled ray hits an emitter contribute.
            Intersection<ad> its1 = scene.ray_intersect<ad, ad>(ray1, active1);
            active1 &= its1.is_valid();
            active1 &= neq(its1.shape->emitter(active1), nullptr);

            Intersection<ad> its2 = scene.ray_intersect<ad, ad>(ray2, active2);
            active2 &= its2.is_valid();
            active2 &= neq(its2.shape->emitter(active2), nullptr);


            Spectrum<ad> bsdf_val1, bsdf_val2;
            Float<ad> pdf1, pdf2;
            if constexpr ( ad ) {
                // Rebuild the outgoing directions from the differentiable hit points so that
                // derivatives with respect to its1.p / its2.p reach the BSDF and geometry terms.
                Vector3fD wo1 = (its1.p - its.p) / its1.t;
                Vector3fD wo2 = (its2.p - its.p) / its2.t;

                if ( bsdf != nullptr ) {
                    bsdf_val1 = bsdf->eval(its, its.sh_frame.to_local(wo1), active1);
                    bsdf_val2 = bsdf->eval(its, its.sh_frame.to_local(wo2), active2);
                } else {
                    bsdf_val1 = bsdf_array->eval(its, its.sh_frame.to_local(wo1), active1);
                    bsdf_val2 = bsdf_array->eval(its, its.sh_frame.to_local(wo2), active2);
                }

                // G_val = |cos| / t^2 converts the solid-angle pdf to an area pdf. It is detached
                // inside the pdf, so it only carries derivatives through the integrand.
                // its1.J / its2.J: presumably the Jacobian of the hit-point parameterization —
                // confirm in Intersection.
                FloatD G_val1 = abs(dot(its1.n, -wo1))/sqr(its1.t);
                pdf1 = bs.pdf1*detach(G_val1);
                bsdf_val1 *= G_val1*its1.J/pdf1;

                FloatD G_val2 = abs(dot(its2.n, -wo2))/sqr(its2.t);
                pdf2 = bs.pdf2*detach(G_val2);
                bsdf_val2 *= G_val2*its2.J/pdf2;

            } else {
                if ( bsdf != nullptr ) {
                    bsdf_val1 = bsdf->eval(its, bs.wo1, active1);
                    bsdf_val2 = bsdf->eval(its, bs.wo2, active2);
                } else {
                    bsdf_val1 = bsdf_array->eval(its, bs.wo1, active1);
                    bsdf_val2 = bsdf_array->eval(its, bs.wo2, active2);
                }
                // pdf1 / pdf2 are expressed in area measure (used only for MIS); the estimate
                // itself divides by the solid-angle pdf.
                FloatC G_val1 = abs(dot(its1.n, -ray1.d))/sqr(its1.t);
                pdf1 = bs.pdf1*G_val1;
                bsdf_val1 /= bs.pdf1;

                FloatC G_val2 = abs(dot(its2.n, -ray2.d))/sqr(its2.t);
                pdf2 = bs.pdf2*G_val2;
                bsdf_val2 /= bs.pdf2;

            }

            // Average over m_bsdf_samples; MIS with emitter sampling is currently disabled by the assert.
            Float<ad> weight1 = 1.f/static_cast<float>(m_bsdf_samples);
            Float<ad> weight2 = 1.f/static_cast<float>(m_bsdf_samples);
            if ( m_light_samples > 0 ) {
                PSDR_ASSERT(0);
                weight1 *= mis_weight<ad>(pdf1, scene.emitter_position_pdf<ad>(its.p, its1, active1));
                weight2 *= mis_weight<ad>(pdf2, scene.emitter_position_pdf<ad>(its.p, its2, active2));
            }
            // Each direction of the dual sample contributes half of the estimate.
            masked(result, active1) += its1.Le(active1)*bsdf_val1*weight1*0.5f;
            masked(result, active2) += its2.Le(active2)*bsdf_val2*weight2*0.5f;
        }
    }

    if ( m_light_samples > 0 ) {
        // Not supported by the dual integrator at the moment: any light sample triggers this assert.
        PSDR_ASSERT(0);
        // Light sampling

        for ( int i = 0; i < m_light_samples; ++i ) {
            PositionSample<ad> ps = scene.sample_emitter_position<ad>(its.p, sampler.next_2d<ad>(), active);
            Mask<ad> active1 = active && ps.is_valid;

            Vector3f<ad> wo = ps.p - its.p;
            Float<ad> dist_sqr = squared_norm(wo);
            Float<ad> dist = safe_sqrt(dist_sqr);
            wo /= dist;

            // Visibility: the first hit toward the emitter sample must be (at least) as far as the
            // sampled point and lie on an emitter.
            Ray<ad> ray1(its.p, wo);
            Intersection<ad> its1 = scene.ray_intersect<ad, ad>(ray1, active1);
            active1 &= its1.is_valid();
            active1 &= (its1.t > dist - ShadowEpsilon) && its1.is_emitter(active1);
            //ps.pdf = scene.emitter_position_pdf<ad>(its.p, its1, active1);

            Float<ad> cos_val = dot(its1.n, -wo);
            Float<ad> G_val = abs(cos_val) / dist_sqr;
            Spectrum<ad> bsdf_val;
            Float<ad> pdf1;
            Vector3f<ad> wo_local = its.sh_frame.to_local(wo);
            if ( bsdf != nullptr ) {
                bsdf_val = bsdf->eval(its, wo_local, active1);
                pdf1 = bsdf->pdf(its, wo_local, active1);
            } else {
                bsdf_val = bsdf_array->eval(its, wo_local, active1);
                pdf1 = bsdf_array->pdf(its, wo_local, active1);
            }
            bsdf_val *= G_val*ps.J/ps.pdf;

            // Convert the BSDF pdf to area measure for MIS (geometry term detached in the AD path).
            if constexpr ( ad ) {
                pdf1 *= detach(G_val);
            } else {
                pdf1 *= G_val;
            }

            Float<ad> weight = 1.f/static_cast<float>(m_light_samples);
            if ( m_bsdf_samples > 0 ) {
                weight *= mis_weight<ad>(ps.pdf, pdf1);
            }
            masked(result, active1) += its1.Le(active1)*bsdf_val*weight;
        }
    }

    return result;
}

/**
 * Builds the distribution used to importance-sample secondary (boundary) edges.
 *
 * @param scene      configured scene (checked with Scene::is_ready()).
 * @param sensor_id  sensors to build guiding for. For options 0 and 1 the distribution is stored in
 *                   the slot of sensor_id[0]; the whole list is forwarded to the AQ setup.
 * @param config     option-dependent parameters:
 *                   - options 0 / 2: config[0] = number of sub-entries each edge CMF entry is split
 *                     into, config[1] / config[2] = number of uniform CDF entries along the second /
 *                     third sample dimension; the vector is also passed to AQ_Option.
 *                   - option 1: config[0..2] = cube resolution, config[3] = samples per cell,
 *                     config[4] = number of estimation rounds.
 * @param option     0: per-sensor adaptive quadrature, 1: per-sensor hyper-cube (Monte Carlo estimate
 *                   of the cell masses), 2: one adaptive quadrature shared by all sensors.
 *
 * The selected option is stored in preprocess_option, which render_secondary_edges() consults.
 */
void DirectIntegratorDual::preprocess_secondary_edges(const Scene &scene, const std::vector<int> &sensor_id, const std::vector<float> &config, int option) {
    PSDR_ASSERT_MSG(scene.is_ready(), "Scene needs to be configured!");
    preprocess_option = option;
    if (option == 0 || option == 2) {
        std::cout << "Guiding using AQ Distribution" << std::endl;
        // config[0..2] are indexed directly below (AQ_Option may read further entries).
        PSDR_ASSERT_MSG(config.size() >= 3U, "AQ guiding requires at least 3 config values!");
        if (option == 0) {
            // Per-sensor distribution stored in m_aq[sensor_id[0]].
            PSDR_ASSERT_MSG(!sensor_id.empty() && sensor_id[0] >= 0 && sensor_id[0] < scene.m_num_sensors,
                            "Invalid sensor id for per-sensor guiding!");
            if ( static_cast<int>(m_aq.size()) != scene.m_num_sensors )
                m_aq.resize(scene.m_num_sensors, nullptr);
            if ( m_aq[sensor_id[0]] == nullptr )
                m_aq[sensor_id[0]] = new AdaptiveQuadratureDistribution3f();
        } else {
            // Single distribution shared by all sensors.
            if ( m_global_aq == nullptr )
                m_global_aq = new AdaptiveQuadratureDistribution3f();
        }

        const int nloop = int(config[0]);
        PSDR_ASSERT_MSG(nloop >= 1, "config[0] must be at least 1!");
#if 1   // set 0 to activate edge draw
        // initial edge distribution: every entry of the edge CMF is split into nloop equally
        // spaced sub-entries, stored interleaved in cdfx
        FloatC tmp_cdf = scene.m_sec_edge_distrb->cmf();
        FloatC tmp_pdf = scene.m_sec_edge_distrb->pmf();
        FloatC cdfx = zero<FloatC>(slices(tmp_cdf)*nloop);

        for (int i=0; i<nloop; ++i) {
            IntC tmpID = arange<IntC>(slices(tmp_cdf))*nloop+(nloop-i-1);
            scatter_add(cdfx, tmp_cdf - i * tmp_pdf/FloatC(nloop), tmpID);
        }

#else
        FloatC raw_cdf = scene.m_sec_edge_distrb->cmf();
        FloatC tmp_cdf = gather<FloatC>(raw_cdf, scene.m_edge_cut);
        size_t cut_size = slices(tmp_cdf);
        IntC   buf_id  = arange<IntC>(cut_size)+1;
        FloatC buf_cdf = zero<FloatC>(cut_size);
        scatter_add(buf_cdf, tmp_cdf, buf_id);
        FloatC tmp_pdf = tmp_cdf - buf_cdf;

        FloatC cdfx = zero<FloatC>(cut_size*nloop);
        for (int i=0; i<nloop; ++i) {
            IntC tmpID = arange<IntC>(cut_size)*nloop+(nloop-i-1);
            scatter_add(cdfx, tmp_cdf - i * tmp_pdf/FloatC(nloop), tmpID);
        }

        // FloatC cdfx = gather<FloatC>(raw_cdf, scene.m_edge_cut); // no cutting


#endif

        cuda_eval(); cuda_sync();
        // Uniform CDFs (k+1)/N along the remaining two sample dimensions.
        FloatC cdfy = (arange<FloatC>(int(config[1]))+FloatC(1.0)) / FloatC(config[1]);
        FloatC cdfz = (arange<FloatC>(int(config[2]))+FloatC(1.0)) / FloatC(config[2]);

        AQ_Option aqconfig = AQ_Option(config, option);

        if (option == 0) {
            m_aq[sensor_id[0]]->setup(scene, sensor_id, cdfx, cdfy, cdfz, aqconfig);
        } else {
            m_global_aq->setup(scene, sensor_id, cdfx, cdfy, cdfz, aqconfig);
        }

        cuda_eval(); cuda_sync();
    } else if (option == 1) {
        PSDR_ASSERT_MSG(config.size() >= 5U, "Cube guiding requires at least 5 config values!");
        std::cout << "Guiding using MC regular Cube Distribution: " << config[0] << " " << config[1] << " " << config[2] << std::endl;
        PSDR_ASSERT_MSG(!sensor_id.empty() && sensor_id[0] >= 0 && sensor_id[0] < scene.m_num_sensors,
                        "Invalid sensor id for per-sensor guiding!");
        const int samples_per_cell = int(config[3]);
        // Used as a divisor below; zero (or negative) would be invalid.
        PSDR_ASSERT_MSG(samples_per_cell >= 1, "config[3] (samples per cell) must be at least 1!");
        int nrounds = int(config[4]);
        if ( static_cast<int>(m_warpper.size()) != scene.m_num_sensors )
            m_warpper.resize(scene.m_num_sensors, nullptr);

        if ( m_warpper[sensor_id[0]] == nullptr )
            m_warpper[sensor_id[0]] = new HyperCubeDistribution3f();
        auto warpper = m_warpper[sensor_id[0]];

        Array<int, 3> wconfig{int(config[0]), int(config[1]), int(config[2])};

        warpper->set_resolution(wconfig);
        int num_cells = warpper->m_num_cells;
        const int64_t num_samples = static_cast<int64_t>(num_cells)*samples_per_cell;
        PSDR_ASSERT(num_samples <= std::numeric_limits<int>::max());

        // Sample k belongs to cell k / samples_per_cell.
        IntC idx = divisor<int>(samples_per_cell)(arange<IntC>(num_samples));
        Vector3iC sample_base = gather<Vector3iC>(warpper->m_cells, idx);

        Sampler sampler;
        sampler.seed(arange<UInt64C>(num_samples));

        // Monte Carlo estimate of each cell's mass (max channel of the edge integrand),
        // averaged over samples_per_cell samples and nrounds rounds.
        FloatC result = zero<FloatC>(num_cells);
        for ( int j = 0; j < nrounds; ++j ) {
            SpectrumC value0;
            std::tie(std::ignore, value0) = eval_secondary_edge<false>(scene, *scene.m_sensors[sensor_id[0]],
                                                                       (sample_base + sampler.next_nd<3, false>())*warpper->m_unit);
            masked(value0, ~enoki::isfinite<SpectrumC>(value0)) = 0.f;
            if ( likely(config[3] > 1) ) {
                value0 /= static_cast<float>(config[3]);
            }
            PSDR_ASSERT(all(hmin(value0) > -Epsilon));
            scatter_add(result, hmax(value0), idx);
        }
        if ( nrounds > 1 ) result /= static_cast<float>(nrounds);
        warpper->set_mass(result);
        cuda_eval(); cuda_sync();
    } else {
        PSDR_ASSERT_MSG(0, "ERROR no such config for guiding SE");
    }
}


void DirectIntegratorDual::render_secondary_edges(const Scene &scene, int sensor_id, SpectrumD &result) const {
    const RenderOption &opts = scene.m_opts;
    Vector3fC sample3 = scene.m_samplers[2].next_nd<3, false>();
    if (preprocess_option == 0 || preprocess_option == 2) {
        FloatC pdf0 = 1.f;
        Vector3fC aq_sample;
        if (m_global_aq != nullptr) {
            aq_sample = m_global_aq->sample(sample3, pdf0);
        }
        else if (!(m_aq.empty() || m_aq[sensor_id] == nullptr)) {
            aq_sample = m_aq[sensor_id]->sample(sample3, pdf0);
        } else {
            aq_sample = sample3;
        }
        auto [idx, value] = eval_secondary_edge<true>(scene, *scene.m_sensors[sensor_id], aq_sample);
        masked(value, ~enoki::isfinite<SpectrumD>(value)) = 0.f;
        masked(value, pdf0 > Epsilon) /= pdf0;
        if ( likely(opts.sppse > 1) ) {
            value /= static_cast<float>(opts.sppse);
        }
        scatter_add(result, value, IntD(idx), idx >= 0);
    } else {
        FloatC pdf0 = (m_warpper.empty() || m_warpper[sensor_id] == nullptr) ?
                       1.f : m_warpper[sensor_id]->sample_reuse(sample3);
        auto [idx, value] = eval_secondary_edge<true>(scene, *scene.m_sensors[sensor_id], sample3);
        masked(value, ~enoki::isfinite<SpectrumD>(value)) = 0.f;
        masked(value, pdf0 > Epsilon) /= pdf0;
        if ( likely(opts.sppse > 1) ) {
            value /= static_cast<float>(opts.sppse);
        }
        scatter_add(result, value, IntD(idx), idx >= 0);
    }
}

/**
 * Detached estimate of the secondary-edge integrand, taking the maximum over several sensors.
 *
 * The segment is sampled with Scene::sample_edge_ray(): bss.p0 lies on an edge and bss.p2 is used as
 * the ray direction. The forward hit (_its2) must be an emitter; the backward hit (_its1) is the
 * shading point _p1. For every lane the maximum shading-corrected BSDF term and the maximum
 * sensor importance over `active_sensor` are combined with the geometric factor base_v.
 *
 * @param scene          scene to trace against
 * @param active_sensor  indices into scene.m_sensors
 * @param sample3        3D primary samples, one per lane
 * @return               per-lane contribution (zero where the segment is invalid)
 */
SpectrumC DirectIntegratorDual::__eval_secondary_edge(const Scene &scene, const std::vector<int> &active_sensor, const Vector3fC &sample3) const {
    BoundarySegSampleDirect bss = scene.sample_edge_ray(sample3);
    MaskC valid = bss.is_valid;

    const Vector3fC &_p0    = detach(bss.p0);
    Vector3fC _dir = bss.p2;

    IntersectionC _its2 = scene.ray_intersect<false>(RayC(_p0, _dir), valid);
    IntersectionC _its1 = scene.ray_intersect<false>(RayC(_p0, -_dir), valid);

    // _p2 is the forward hit itself, so the distance test only rejects non-finite hit positions.
    Vector3fC _p2 = _its2.p;
    valid &= _its2.is_emitter(valid) && _its2.is_valid() && norm(_its2.p - _p2) < ShadowEpsilon;

    valid &= _its1.is_valid();
    Vector3fC &_p1 = _its1.p;

    // Geometric factor of the boundary term (same formula as eval_secondary_edge()).
    FloatC      dist    = norm(_p2 - _p1),
                cos2    = abs(dot(_its2.n, -_dir));
    Vector3fC   e       = cross(bss.edge, _dir);
    FloatC      sinphi  = norm(e);
    Vector3fC   proj    = normalize(cross(e, _its2.n));
    FloatC      sinphi2 = norm(cross(_dir, proj));
    FloatC      base_v  = (_its1.t/dist)*(sinphi/sinphi2)*cos2;
    valid &= (sinphi > Epsilon) && (sinphi2 > Epsilon);

#if 0
    return (_its2.Le(valid)*(base_v/bss.pdf)) & valid;
#else
    FloatC sensor_val(0.0);
    SpectrumC bsdf_val(0.0);

    // The BSDF lookup at _p1 does not depend on the sensor; evaluate it once.
    BSDFArrayC bsdf_array = _its1.shape->bsdf(valid);

    // Maximum bsdf value and sensor importance over all active sensors
    for (int sid : active_sensor) {
        const Sensor &sensor = *scene.m_sensors[sid];
        SensorDirectSampleC sds = sensor.sample_direct(_p1);
        RayC camera_ray = sensor.sample_primary_ray(sds.q);
        Vector3fC d0 = -camera_ray.d;
        Vector3fC d0_local = _its1.sh_frame.to_local(d0);
        // The trailing factor accounts for the BSDF's asymmetry caused by shading normals.
        SpectrumC bsdf_val_temp = bsdf_array->eval(_its1, d0_local, valid) *
                                  abs((_its1.wi.z()*dot(d0, _its1.n))/(d0_local.z()*dot(_dir, _its1.n)));
        bsdf_val   = max(bsdf_val, bsdf_val_temp);
        sensor_val = max(sensor_val, sds.sensor_val);
    }
    return (bsdf_val*_its2.Le(valid)*(sensor_val*base_v/bss.pdf)) & valid;
#endif

}

/**
 * Evaluates the direct-illumination secondary-edge (boundary) integrand at the given samples.
 *
 * The boundary segment is sampled in one of two ways:
 *  - m_edge_direct != 0: Scene::sample_boundary_segment_direct() is used and bss.p2 is an endpoint.
 *  - otherwise: Scene::sample_edge_ray() is used and bss.p2 is taken as the ray direction.
 *
 * From the edge point _p0 the segment is traced forward to the emitter point _p2 and backward to
 * the shading point _p1. The point _p1 is then projected onto `sensor`.
 *
 * @tparam ad      true: differentiable boundary term; false: un-weighted value used for guiding
 * @param scene    scene to trace against
 * @param sensor   sensor onto which _p1 is projected
 * @param sample3  3D primary samples, one per lane
 * @return {pixel index, contribution}. With ad == true the index is -1 on invalid lanes and the
 *         contribution is result - detach(result), i.e. zero in value but carrying the gradient.
 *         With ad == false the index is -1 and the contribution is the value without the
 *         normal-velocity factor.
 */
template <bool ad>
std::pair<IntC, Spectrum<ad>> DirectIntegratorDual::eval_secondary_edge(const Scene &scene, const Sensor &sensor, const Vector3fC &sample3) const {
    BoundarySegSampleDirect bss;
    if (m_edge_direct) {
        bss = scene.sample_boundary_segment_direct(sample3);
    } else {

#if 0
        // draw edge per step
        std::cout << "Draw edge walking path" << std::endl;

        int interval = 100;
        FloatC rndtx = linspace<FloatC>(0.0, 1.0, interval);
        FloatC rndty = zero<FloatC>(interval);
        FloatC rndtz = zero<FloatC>(interval);
        Vector3fC rndt;
        rndt.x() = rndtx;
        rndt.y() = rndty;
        rndt.z() = rndtz;

        std::ofstream myfile;
        myfile.open ("walk.obj");

        BoundarySegSampleDirect bsst = scene.sample_edge_ray(rndt);
        Vector3fC tmp0 = detach(bsst.p0);

        for(int i=0; i<interval; ++i) {
            myfile << "v " << tmp0.x()[i] << " " << tmp0.y()[i] << " " << tmp0.z()[i] << "\n";
        }

        for(int i=0; i<interval-1; ++i) {
            myfile << "l " << i+1 << " " << i+2 << "\n";
        }

        myfile.close();
#endif

        bss = scene.sample_edge_ray(sample3);
    }

    MaskC valid = bss.is_valid;

    // _p0 on a face edge, _p2 on an emitter
    const Vector3fC &_p0    = detach(bss.p0);
    Vector3fC       _p2, _dir;

    if (m_edge_direct) {
        _p2  = bss.p2;
        _dir = normalize(_p2 - _p0);
    } else {
        _dir = bss.p2;
    }

    // check visibility between _p0 and _p2
    // In the differentiable path the hit triangle's data is also recorded in tri_info; it is used
    // at the end to intersect a differentiable shadow ray with that triangle.
    IntersectionC _its2;
    TriangleInfoD tri_info;
    if constexpr ( ad ) {
        _its2 = scene.ray_intersect<false>(RayC(_p0, _dir), valid, &tri_info);
    } else {
        _its2 = scene.ray_intersect<false>(RayC(_p0, _dir), valid); 
    }

    // In edge-ray mode there is no pre-sampled endpoint: _p2 is the hit itself, so the distance
    // test below is trivially satisfied for finite hit positions.
    if (!m_edge_direct) {
        _p2 = _its2.p;
    }

    valid &= _its2.is_emitter(valid) && _its2.is_valid() && norm(_its2.p - _p2) < ShadowEpsilon;

    // trace another ray in the opposite direction to complete the boundary segment (_p1, _p2)
    IntersectionC _its1 = scene.ray_intersect<false>(RayC(_p0, -_dir), valid);
    valid &= _its1.is_valid();
    Vector3fC &_p1 = _its1.p;

    // project _p1 onto the image plane and compute the corresponding pixel id
    SensorDirectSampleC sds = sensor.sample_direct(_p1);
    valid &= sds.is_valid;

    // trace a camera ray toward _p1 in a differentiable fashion
    // (lanes whose camera ray does not reach _p1 within ShadowEpsilon are discarded)
    Ray<ad> camera_ray;
    Intersection<ad> its1;
    if constexpr ( ad ) {
        camera_ray = sensor.sample_primary_ray(Vector2fD(sds.q));
        its1 = scene.ray_intersect<true, false>(camera_ray, valid);
        valid &= its1.is_valid() && norm(detach(its1.p) - _p1) < ShadowEpsilon;
    } else {
        camera_ray = sensor.sample_primary_ray(sds.q);
        its1 = scene.ray_intersect<false>(camera_ray, valid);
        valid &= its1.is_valid() && norm(its1.p - _p1) < ShadowEpsilon;
    }

    // calculate base_value
    // base_v = (_its1.t / |_p2 - _p1|) * (sinphi / sinphi2) * cos2; lanes where either sine is
    // (near) zero are rejected.
    FloatC      dist    = norm(_p2 - _p1),
                cos2    = abs(dot(_its2.n, -_dir));
    Vector3fC   e       = cross(bss.edge, _dir);
    FloatC      sinphi  = norm(e);
    Vector3fC   proj    = normalize(cross(e, _its2.n));
    FloatC      sinphi2 = norm(cross(_dir, proj));
    FloatC      base_v  = (_its1.t/dist)*(sinphi/sinphi2)*cos2;
    valid &= (sinphi > Epsilon) && (sinphi2 > Epsilon);

    // evaluate BSDF at _p1
    SpectrumC bsdf_val;
    Vector3fC d0;
    if constexpr ( ad ) {
        d0 = -detach(camera_ray.d);
    } else {
        d0 = -camera_ray.d;
    }
    Vector3fC d0_local = _its1.sh_frame.to_local(d0);
    // Fast path: with a single BSDF or a single mesh, the first mesh's BSDF is used for all lanes.
    if ( scene.m_bsdfs.size() == 1U || scene.m_meshes.size() == 1U ) {
        const BSDF *bsdf = scene.m_meshes[0]->m_bsdf;
        bsdf_val = bsdf->eval(_its1, d0_local, valid);
    } else {
        BSDFArrayC bsdf_array = _its1.shape->bsdf(valid);
        bsdf_val = bsdf_array->eval(_its1, d0_local, valid);
    }
    // accounting for BSDF's asymmetry caused by shading normals
    FloatC correction = abs((_its1.wi.z()*dot(d0, _its1.n))/(d0_local.z()*dot(_dir, _its1.n)));
    masked(bsdf_val, valid) *= correction;

    SpectrumC value0 = (bsdf_val*_its2.Le(valid)*(base_v*sds.sensor_val/bss.pdf)) & valid;
    if constexpr ( ad ) {
        // Orientation sign of the boundary term, from bss.edge2 and the normal n built from the emitter normal.
        Vector3fC n = normalize(cross(_its2.n, proj));
        value0 *= sign(dot(e, bss.edge2))*sign(dot(e, n));

        const Vector3fD &v0 = tri_info.p0,
                        &e1 = tri_info.e1,
                        &e2 = tri_info.e2;

        // Differentiable shadow ray from the camera hit toward the edge point, intersected with the
        // triangle recorded in tri_info. u2 is rebuilt from detached triangle data with the
        // differentiable barycentrics uv.
        RayD shadow_ray(its1.p, normalize(bss.p0 - its1.p));
        Vector2fD uv;
        std::tie(uv, std::ignore) = ray_intersect_triangle<true>(v0, e1, e2, shadow_ray);
        Vector3fD u2 = bilinear<true>(detach(v0), detach(e1), detach(e2), uv);

        // dot(n, u2) provides the normal-velocity factor; result - detach(result) is zero in
        // value but keeps the gradient.
        SpectrumD result = (SpectrumD(value0)*dot(Vector3fD(n), u2)) & valid;
        return { select(valid, sds.pixel_idx, -1), result - detach(result) };
    } else {
        // returning the value without multiplying normal velocity for guiding
        return { -1, value0 };
    }
}

} // namespace psdr
