#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include <algorithm1.h>
#include <core/statistics.h>
#include <render/spiral.h>
#include <fmt/core.h>
#include <core/math_func.h>
#include "boundary_bidir.h"


// ============================================================================
//                          Radiance tracer with nee
// ============================================================================

/// Stores `value` in the radiance slot of the current depth and then advances
/// the depth counter, so the next call writes to the following slot.
void RadianceTracer::handleNee(const Spectrum &value)
{
    PSDR_ASSERT_MSG(depth < radiances.size(), "depth = {}, radiances.size() = {}", depth, radiances.size());
    // Post-increment: write at the current depth, then move to the next one.
    radiances[depth++] = value;
}

/**
 * Runs VolpathBase::Li while keeping a per-depth record in `radiances`.
 *
 * The depth counter is reset and `radiances` is re-created with
 * `max_bounces` zero-initialised slots before tracing. The returned value is
 * the sum of the first `max_bounces` slots; in debug builds it is checked
 * against the value returned by VolpathBase::Li.
 *
 * NOTE(review): the slots are presumably filled through handleNee() during
 * VolpathBase::Li -- confirm against the base class.
 */
Spectrum RadianceTracer::Li(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Ray &ray, int max_bounces, bool incEmission)
{
    // Start from a clean per-depth buffer.
    radiances.clear();
    radiances.resize(max_bounces, Spectrum(0.));
    depth = 0;

    Spectrum reference = VolpathBase::Li(scene, sampler, medium, ray, max_bounces, incEmission);

    // Accumulate the recorded per-depth contributions.
    Spectrum total(0.);
    int      slot = 0;
    while (slot < max_bounces)
    {
        total += radiances[slot];
        ++slot;
    }

    assert(reference.isApprox(total));
    return total;
}

/**
 * In-scattered radiance estimate at the boundary intersection `its` for the
 * direction `wi`.
 *
 * Two contributions are summed:
 *   1. a direct emitter sample weighted by the phase function of the
 *      intersection's interior medium (also recorded through handleNee());
 *   2. a phase-function-sampled direction `wo`, continued with
 *      VolpathBase::Li from a point pushed 1e-2 along the negated geometric
 *      normal, with emission excluded.
 *
 * Returns zero immediately when `max_bounces` is negative.
 */
Spectrum RadianceTracer::_Lins(const Scene &scene, RndSampler *sampler,
                               const Intersection &its,
                               const Vector &wi, int max_bounces) {
    if (max_bounces < 0)
        return Spectrum(0.);

    Spectrum total(0.);

    // --- 1. next-event estimation towards an emitter
    DirectSamplingRecord dRec(its.p);
    Spectrum emitted = scene.sampleAttenuatedEmitterDirect(
        dRec, its, sampler->next2D(), sampler, nullptr /* unused */);
    const Medium        *interior = its.ptr_med_int;
    const PhaseFunction *phase    = scene.getPhase(interior->phase_id);
    Float                phaseVal = phase->eval(wi, dRec.dir);
    // Compute the phase-weighted emitter sample once; it is both accumulated
    // and recorded as this depth's contribution.
    const Spectrum direct = emitted * phaseVal;
    total += direct;
    handleNee(direct);

    // --- 2. continue the path along a phase-sampled direction
    Vector wo;
    Float  scatter_weight = phase->sample(wi, sampler->next2D(), wo);
    // VolpathBase vb;
    Vector inward = -its.geoFrame.n;
    Vector origin = its.p + inward * 1e-2; // if 1e-3, will cause a lot of dead loops, same for Ray{p, wo}.shift()
    total += scatter_weight * VolpathBase::Li(scene, sampler, its.getTargetMedium(wo),
                                              Ray{ origin, wo }, max_bounces, false /* inc_emission */);
    return total;
}

/**
 * Samples the source subpath starting at a point `p` inside `medium`.
 *
 * Resets the depth counter, re-creates `radiances` with `max_bounces`
 * zero-initialised slots, and evaluates Lins() with `max_bounces - 1`.
 * The per-depth slots are then turned into a running (prefix) sum, so that
 * radiances[i] holds the total of slots 0..i.
 *
 * @return the value returned by Lins() together with a copy of the
 *         prefix-summed `radiances` vector.
 *
 * NOTE(review): `optnone` disables optimisation of this function on clang;
 * presumably a debugging aid -- confirm whether it is still needed.
 */
__attribute__((optnone)) std::tuple<Spectrum, std::vector<Spectrum>>
RadianceTracer::sampleSource(
    const Scene &scene, RndSampler *sampler,
    const Medium *medium, const Vector &p, const Vector &wi, int max_bounces)
{
    depth = 0;
    radiances.clear();
    radiances.resize(max_bounces, Spectrum(0.));
    Spectrum value = Lins(scene, sampler, medium, p, wi, max_bounces - 1);
    // In-place prefix sum; the index is unsigned to match size().
    for (std::size_t i = 1; i < radiances.size(); ++i)
    {
        radiances[i] += radiances[i - 1];
    }
    // back() on an empty vector is undefined behaviour, so only run the
    // consistency check when at least one slot exists (max_bounces > 0).
    assert(radiances.empty() || value.isApprox(radiances.back()));
    return { value, radiances };
}

/**
 * Samples the source subpath starting at the boundary intersection `its`.
 *
 * Resets the depth counter, re-creates `radiances` with `max_bounces`
 * zero-initialised slots, and evaluates _Lins() with `max_bounces - 1`.
 * The per-depth slots are then turned into a running (prefix) sum, so that
 * radiances[i] holds the total of slots 0..i.
 *
 * @return the value returned by _Lins() together with a copy of the
 *         prefix-summed `radiances` vector.
 */
std::tuple<Spectrum, std::vector<Spectrum>>
RadianceTracer::sampleSource(
    const Scene &scene, RndSampler *sampler,
    const Intersection &its, const Vector &wi, int max_bounces) {
    depth = 0;
    radiances.clear();
    radiances.resize(max_bounces, Spectrum(0.));
    Spectrum value = _Lins(scene, sampler, its, wi, max_bounces - 1);
    // In-place prefix sum; the index is unsigned to match size().
    for (std::size_t i = 1; i < radiances.size(); ++i) {
        radiances[i] += radiances[i - 1];
    }
    // radiances[0] = value; // FIXME
    // back() on an empty vector is undefined behaviour (and it appears both in
    // the condition and in the message arguments), so the consistency check is
    // only performed when at least one slot exists (max_bounces > 0).
    if (!radiances.empty()) {
        PSDR_ASSERT_MSG(value.isApprox(radiances.back()),
                        "radiances.size = {}, value: {}, {}, {}, radiances.back(): {}, {}, {}",
                        radiances.size(),
                        value[0], value[1], value[2],
                        radiances.back()[0], radiances.back()[1], radiances.back()[2]);
    }
    return { value, radiances };
}
// ============================================================================
//                          Radiance tracer without nee
// ============================================================================

/// NEE hook of the tracer without next-event estimation: none of the
/// arguments are used and the contribution is always zero.
Spectrum RadianceTracer2::handleNee(const Scene &scene, RndSampler *sampler,
                                    const Medium *medium, const Vector &p, const Vector &wi)
{
    Spectrum none(0.);
    return none;
}

/// Returns the radiance emitted at `its` towards `wo`.
/// `its` must be an emitter (checked in debug builds); `incEmission` is not
/// consulted by this implementation.
Spectrum RadianceTracer2::handleEmission(const Intersection &its, const Vector &wo, bool incEmission)
{
    assert(its.isEmitter());
    Spectrum emitted = its.Le(wo);
    return emitted;
}

// ============================================================================
//                           ImportanceTracer
// ============================================================================

// connect the boundary interaction to the camera directly
/**
 * Handles the path whose only detector-side vertex is the boundary point
 * itself (ctx.its_b), i.e. the boundary point is connected straight to the
 * sensor.
 *
 * Steps:
 *   1. sample an attenuated sensor connection from *ctx.its_b;
 *   2. pick a pixel via camera.sampleDirectPixel() and return early if it is
 *      outside ctx.d_image;
 *   3. weight the pixel's d_image value by the sensor value, transmittance,
 *      `throughput` and the emitter-subpath radiance stored at index
 *      ctx.max_bounces - 1;
 *   4. optionally apply a path-MIS or surface-MIS weight, depending on
 *      m_sampling_mode;
 *   5. pass the summed value to algorithm1::d_velocity().
 *
 * The `medium` and `p` parameters are not used by this body.
 * With FORWARD defined, the derivative parameter is zeroed before the
 * d_velocity call and its value afterwards is written to ctx.grad_image.
 */
void ImportanceTracer::handleBoundary(const Scene &scene, RndSampler *sampler,
                                      const Medium *medium, const Vector &p,
                                      const Spectrum &throughput) {
    CameraDirectSamplingRecord cRec;
    Spectrum                   transmittance{ scene.sampleAttenuatedSensorDirect(*ctx.its_b, sampler, cRec) }; // transmittance
    // sample 1 pixel among 4.
    auto [pixel_id, sensor_value /*contain G*/] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    Array2i pixel_idx                           = unravel_index(pixel_id, scene.camera.getCropSize());
    // Nothing to accumulate when the sampled pixel lies outside the image block.
    if (!ctx.d_image->contains(pixel_idx))
        return;
    Spectrum d_value = ctx.d_image->get(pixel_idx);
    // radiances contain phase value
    d_value *= sensor_value * transmittance *          /* sensor subpath */
               throughput *                            /* sigS / pdf_B */
               ctx.radiances->at(ctx.max_bounces - 1); /* emitter subpath */
    if (m_sampling_mode & EMISPath) {
        // Path MIS
        Float mis_weight = 1.;
        Float nsampleA   = scene.camera.getNumPixels(); // FIXME: might need to be divided by 4
        Float pdfA       = nsampleA / scene.getMediumArea() *
                     scene.camera.pdfPixel(pixel_idx.x(), pixel_idx.y(), -cRec.dir);
        Float pdfB = scene.camera.eval(pixel_idx.x(), pixel_idx.y(), ctx.its_b->p) *
                     std::abs(cRec.dir.dot(ctx.its_b->geoFrame.n));
        // NOTE(review): the pdfs are squared here, which is the power-heuristic
        // form (exponent 2); the previous comment called this the balance
        // heuristic, and handleSensor() uses an unsquared form -- confirm which
        // one is intended.
        mis_weight = pdfA * pdfA / (pdfA * pdfA + pdfB * pdfB);
        d_value *= mis_weight;
    } else if (m_sampling_mode & (EMIS | EDebugMIS)) {
        // surface MIS
        Float mis_weight = 1.;
        Float pdfA       = scene.camera.getNumPixels() / scene.getArea(),
              pdfB       = scene.camera.evalFilter(pixel_idx.x(), pixel_idx.y(), ctx.its_b->p) *
                     cRec.baseVal * std::abs(cRec.dir.dot(ctx.its_b->geoFrame.n));
        // different MIS heuristics
        if (m_sampling_mode & (EMISBalance | EDebugMISBalance)) {
            mis_weight = misWeightBalance(pdfA, pdfB);
        } else if (m_sampling_mode & (EMISPower | EDebugMISPower)) {
            mis_weight = misWeightPower(pdfA, pdfB);
        } else {
            Throw("unknown sampling mode");
        }
        d_value *= mis_weight;
    }

    // FORWARD builds: reset the derivative parameter so that the value read
    // back below reflects this sample only.
#ifdef FORWARD
    ctx.sceneAD->getDer().zeroParameter();
#endif
    algorithm1::d_velocity(*ctx.sceneAD, *ctx.its_b, d_value.sum());
    // FORWARD builds: store the resulting parameter in the first channel of the
    // gradient image at the sampled pixel.
#ifdef FORWARD
    Float param = ctx.sceneAD->getDer().getParameter();
    ctx.grad_image->put(pixel_idx, Spectrum(param, 0, 0));
#endif
}

/**
 * Accumulates the boundary-term derivative for a detector subpath vertex that
 * has been connected to the sensor (`cRec` describes that connection).
 *
 * The sampled pixel's d_image value is weighted by the sensor value,
 * `throughput` and the emitter-subpath radiance stored at index
 * ctx.max_bounces - depth, then optionally by an MIS weight, and finally
 * passed (summed over channels) to algorithm1::d_velocity().
 *
 * Returns early, without contributing, when the pixel is outside ctx.d_image
 * or when the path-MIS weight is not finite.
 *
 * With FORWARD defined, the `param` of the shape that requires a gradient is
 * zeroed before the d_velocity call and written to ctx.grad_image afterwards.
 */
void ImportanceTracer::handleSensor(const Scene &scene, RndSampler *sampler,
                                    const CameraDirectSamplingRecord &cRec,
                                    int depth, const Spectrum &throughput) {
    // sample 1 pixel among 4.
    auto [pixel_id, sensor_value] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    Array2i pixel_idx = unravel_index(pixel_id, scene.camera.getCropSize());
    if (!ctx.d_image->contains(pixel_idx))
        return;
    Spectrum d_value = ctx.d_image->get(pixel_idx);
    // radiances contain phase value
    // The index into ctx.radiances must be a valid slot.
    assert(ctx.max_bounces - depth >= 0);
    assert(ctx.max_bounces - depth < ctx.radiances->size());
    d_value *= sensor_value * throughput * ctx.radiances->at(ctx.max_bounces - depth);

    if (m_sampling_mode & EMISPath) {
        // Path MIS: compare the forward pdf of the stored vertices (times the
        // pixel pdf and the number of pixels) against their reverse pdf.
        Float   mis_weight = 1.;
        Vertex &prev       = mis_ctx.vertices.back();
        // prev.pdf_rev is a temporary value, it will be updated in the next iteration
        prev.pdf_rev = scene.camera.eval(pixel_idx.x(), pixel_idx.y(), prev.getP()) *
                       scene.evalTransmittance(scene.camera.cpos, true, prev.getP(), true, nullptr, sampler);
        // Non-medium vertices additionally get the cosine between the
        // connection direction and the geometric normal.
        if (prev.type != Vertex::EType::EMedium)
            prev.pdf_rev *= std::abs(cRec.dir.dot(prev.its->geoFrame.n));
        Float pdfA     = mis_ctx.pdfFwd() * scene.camera.pdfPixel(pixel_idx.x(), pixel_idx.y(), -cRec.dir);
        Float pdfB     = mis_ctx.pdfRev();
        Float nsampleA = scene.camera.getNumPixels();
        mis_weight     = pdfA * nsampleA /
                     (pdfA * nsampleA + pdfB); // balance heuristic
        // A non-finite weight (e.g. 0/0) would poison the gradient: log it and
        // drop this sample.
        if(!isfinite(mis_weight))
        {
            PSDR_INFO("mis weight is not finite");
            return;
        }
        d_value *= mis_weight;
        // NOTE(review): this assert cannot fire, the non-finite case returned above.
        assert(isfinite(mis_weight));
    }

    // the mis only applies to the path with the first vertex being the boundary vertex
    if (depth == 1 && (m_sampling_mode & (EMIS | EDebugMIS))) {
        Float mis_weight = 1.;
        // NOTE(review): pdfB is evaluated at cRec.pixel_idx here, whereas
        // handleBoundary() uses the pixel returned by sampleDirectPixel()
        // (pixel_idx) -- confirm this difference is intended.
        Float pdfA       = scene.camera.getNumPixels() / scene.getArea(),
              pdfB       = scene.camera.evalFilter(cRec.pixel_idx.x(), cRec.pixel_idx.y(), ctx.its_b->p) *
                     cRec.baseVal * std::abs(cRec.dir.dot(ctx.its_b->geoFrame.n));
        if (m_sampling_mode & (EMISBalance | EDebugMISBalance)) {
            mis_weight = misWeightBalance(pdfA, pdfB);
        } else if (m_sampling_mode & (EMISPower | EDebugMISPower)) {
            mis_weight = misWeightPower(pdfA, pdfB);
        } else {
            Throw("unknown sampling mode");
        }
        d_value *= mis_weight;
    }

    // FORWARD builds: reset the parameter of the differentiated shape so the
    // value read back below reflects this sample only.
#ifdef FORWARD
    int shape_idx                                      = scene.getShapeRequiresGrad();
    ctx.sceneAD->getDer().shape_list[shape_idx]->param = 0;
#endif
    algorithm1::d_velocity(*ctx.sceneAD, *ctx.its_b, d_value.sum());
    // FORWARD builds: store the resulting parameter in the first channel of the
    // gradient image at the sampled pixel.
#ifdef FORWARD
    Float param = ctx.sceneAD->getDer().shape_list[shape_idx]->param;
    ctx.grad_image->put(pixel_idx, Spectrum(param, 0, 0));
#endif
}

// connect to the camera
/**
 * Handles a medium interaction at `p` on the detector subpath: the point is
 * connected to the sensor and, when the connection carries energy, the result
 * is forwarded to handleSensor().
 *
 * `wi` / `wo` are the directions passed to the phase function pdf when the
 * path-MIS bookkeeping is active; `throughput` is the weight accumulated so
 * far and `depth` is passed through to handleSensor() unchanged.
 */
void ImportanceTracer::handleMedium(const Scene &scene, RndSampler *sampler,
                                    const Medium *medium, const Vector &p, 
                                    const Vector &wi, const Vector &wo,
                                    int depth, const Spectrum &throughput) {
    CameraDirectSamplingRecord cRec;

    // Attenuated sensor connection from p; also fills cRec.
    Spectrum value{ scene.sampleAttenuatedSensorDirect(p, medium, sampler, cRec) };

    auto  *phase = scene.getPhase(medium->phase_id);
    if (m_sampling_mode & EMISPath) {
        // Path MIS bookkeeping: register this medium vertex and update the
        // reverse pdf of the previous vertex. This runs before the sensor
        // connection is evaluated, independent of whether it contributes.
        Vertex &prev = mis_ctx.vertices.back();
        Vertex curr = Vertex::createMedium(p, medium,
                                           /* pdf_fwd */ prev.pdf_next * geometric(prev.getP(), p) *
                                               scene.evalTransmittance(p, false, prev.getP(), true, medium, sampler),
                                           /* pdf_next */ phase->pdf(wi, wo));
        prev.pdf_rev = curr.convertDensity(phase->pdf(wo, wi), prev);
        // The transmittance factor is not applied when the previous vertex is
        // the boundary vertex.
        if (prev.type != Vertex::EBoundary)
            prev.pdf_rev *= scene.evalTransmittance(p, false, prev.getP(), true, medium, sampler);
        mis_ctx.append(curr);
    }
    // Skip connections that carry (numerically) no energy.
    if (!value.isZero(Epsilon) && cRec.baseVal > Epsilon) {
        // Phase function towards the sensor direction times the path weight.
        value *= phase->eval(wi, cRec.dir) * throughput;
        if (!value.isZero(Epsilon))
            handleSensor(scene, sampler, cRec, depth, value);
    }
}

// ============================================================================
//                           Importance Tracer 2
// ============================================================================
void ImportanceTracer2::handleSensor(const Scene &scene, RndSampler *sampler,
                                     const CameraDirectSamplingRecord &cRec,
                                     int depth, const Spectrum &throughput)
{
    auto [pixel_id, sensor_value] = scene.camera.sampleDirectPixel(cRec, sampler->next1D());
    Array2i pixel_idx = unravel_index(pixel_id, scene.camera.getCropSize());
    if (!d_image.contains(pixel_idx))
        return;
    Spectrum d_value = d_image.get(pixel_idx);
    // radiances contain phase value
    d_value *= sensor_value * throughput * radiance;
#ifdef FORWARD
    int shape_idx = scene.getShapeRequiresGrad();
    sceneAD.getDer().shape_list[shape_idx]->param = 0;
#endif
    algorithm1::d_velocity(sceneAD, its_b, d_value.sum());

#ifdef FORWARD
    Float param = sceneAD.getDer().shape_list[shape_idx]->param;
    grad_image.put(pixel_idx, Spectrum(param, 0, 0));
#endif
}

// =============================================================================
//                              BoundaryBidirectional
// =============================================================================
/// Builds the sampling distributions for `scene`: always the per-shape
/// distribution, and the per-face distributions only in adaptive mode.
/// NOTE(review): m_adaptive is read here but not assigned by this
/// constructor; presumably it has an in-class default -- confirm in the header.
BoundaryBidirectional::BoundaryBidirectional(const Scene &scene)
{
    m_shapeDistribution = buildShapeDistribution(scene);
    if (!m_adaptive)
        return;
    m_faceDistributions = buildFaceDistributions(scene);
}

/// Constructs the integrator from `props`.
/// Reads "adaptive" (bool, default false) and "adaptive_mode" (int, default 0);
/// both base classes are initialised from the same property set.
BoundaryBidirectional::BoundaryBidirectional(const Properties &props)
    : Integrator(props), MISIntegrator(props) {
    const bool adaptive      = props.get<bool>("adaptive", false);
    const int  adaptive_mode = props.get<int>("adaptive_mode", 0);

    m_adaptive      = adaptive;
    m_adaptive_mode = adaptive_mode;
}

/// (Re)builds the sampling distributions for `scene`: the per-shape
/// distribution always, the per-face distributions only when m_adaptive is set.
void BoundaryBidirectional::configure(const Scene &scene) {
    m_shapeDistribution = buildShapeDistribution(scene);

    if (m_adaptive) {
        m_faceDistributions = buildFaceDistributions(scene);
    }
}

/// Builds a normalised discrete distribution with one entry per shape in
/// scene.shape_list: medium-transition shapes are weighted by their area,
/// all other shapes get weight zero.
DiscreteDistribution BoundaryBidirectional::buildShapeDistribution(const Scene &scene) const
{
    DiscreteDistribution areas;
    for (const auto &shape : scene.shape_list)
    {
        // FIXME: check if the shape contains a medium, not sure this is a valid check in a complex scene
        Float weight = 0;
        if (shape->isMediumTransition())
            weight = shape->getArea();
        // One entry per shape keeps the distribution index equal to the shape index.
        areas.append(weight);
    }
    areas.normalize();
    return areas;
}

/**
 * Builds one normalised per-triangle distribution for every shape in
 * scene.shape_list, used for adaptive boundary sampling.
 *
 * The weight of a triangle depends on m_adaptive_mode:
 *   - 0: 1 / d^2
 *   - 1: transmittance / d^2
 * where d is the distance from the camera position to the triangle centroid
 * and the transmittance is evaluated between the camera position and the
 * centroid with a locally constructed RndSampler(0, 0).
 *
 * The transmittance is only evaluated in mode 1; mode 0 never used its value,
 * so evaluating it there was wasted work for every triangle.
 *
 * NOTE(review): for any other m_adaptive_mode nothing is appended, so the
 * distributions end up with no entries before normalize() is called --
 * confirm whether an unknown mode should be rejected instead.
 */
std::vector<DiscreteDistribution> BoundaryBidirectional::buildFaceDistributions(const Scene &scene) const {
    std::vector<DiscreteDistribution> dists;
    RndSampler                        sampler(0, 0);
    for (const auto &shape : scene.shape_list) {
        DiscreteDistribution dist;
        for (int i = 0; i < shape->num_triangles; ++i) {
            const auto &ind    = shape->getIndices(i);
            const auto &v0     = shape->getVertex(ind(0));
            const auto &v1     = shape->getVertex(ind(1));
            const auto &v2     = shape->getVertex(ind(2));
            Vector      center = (v0 + v1 + v2) / 3;
            Float       d      = (center - scene.camera.cpos).norm();
            switch (m_adaptive_mode) {
                case 0:
                    // inverse squared distance to the camera
                    dist.append(1. / square(d));
                    break;
                case 1: {
                    // inverse squared distance, attenuated by the transmittance
                    // between the camera and the triangle centroid
                    Float trans = scene.evalTransmittance(scene.camera.cpos, false,
                                                          center, true, nullptr, &sampler);
                    dist.append(trans / square(d));
                    break;
                }
                default:
                    break;
            }
        }
        dist.normalize();
        dists.push_back(dist);
    }
    return dists;
}

/**
 * Samples a point on the scene's shapes and wraps it in an Intersection.
 *
 * A shape is picked from m_shapeDistribution (re-using the first random
 * number), then a position on that shape is sampled -- with the shape's face
 * distribution in adaptive mode, with the shape's default sampling otherwise.
 *
 * @param _rnd two random numbers; a local copy is modified during sampling.
 * @return the filled intersection and the combined pdf (shape pdf times
 *         position pdf). Both frames are built from the sampled normal and
 *         the intersection's J is set to 1.
 */
std::pair<Intersection, Float> BoundaryBidirectional::sampleBoundaryPoint(
    const Scene &scene, const Array2 &_rnd) const {
    Array2 rnd(_rnd);
    Float  pdf;

    // NOTE: this might affect the mis weight
    int          shape_id = m_shapeDistribution.sampleReuse(rnd[0], pdf);
    const Shape *shape    = scene.getShape(shape_id);

    PositionSamplingRecord pRec;
    int                    tri_id;
    if (m_adaptive)
        tri_id = shape->samplePosition(rnd, pRec, &m_faceDistributions[shape_id]);
    else
        tri_id = shape->samplePosition(rnd, pRec);
    pdf *= pRec.pdf;

    Intersection its;
    // which primitive was hit
    its.indices[0]  = shape_id;
    its.indices[1]  = tri_id;
    its.barycentric = pRec.uv;
    // scene objects attached to the shape
    its.ptr_shape   = shape;
    its.ptr_bsdf    = scene.getBSDF(shape->bsdf_id);
    its.ptr_med_ext = scene.getMedium(shape->med_ext_id);
    its.ptr_med_int = scene.getMedium(shape->med_int_id);
    // local geometry
    its.p           = pRec.p;
    its.geoFrame    = Frame(pRec.n);
    its.shFrame     = Frame(pRec.n);
    // sampling bookkeeping
    its.pdf         = pdf;
    its.J           = 1.0;
    return { its, pdf };
}

/**
 * Draws one boundary sample and accumulates its contribution.
 *
 * A point is sampled on the volume boundary, a direction `wo` is drawn
 * uniformly on the sphere, the source subpath is traced from the boundary
 * point with RadianceTracer::sampleSource(), and an ImportanceTracer then
 * connects the boundary point to the sensor and/or traces the detector
 * subpath, depending on m_sampling_mode.
 *
 * @param sceneAD scene value/derivative pair; the value scene is used for
 *                sampling and sceneAD is handed to the importance tracer.
 * @param bRec    per-sample record supplying the sampler, max_bounces, the
 *                d_image to read from and the grad_image to write to.
 */
void BoundaryBidirectional::sampleBoundary(SceneAD &sceneAD, VolumeBoundaryQueryRecord &bRec) const
{
    const Scene &scene = sceneAD.val;
    RndSampler *sampler = bRec.sampler;
    int max_bounces = bRec.max_bounces;
    // 1. sample a point on the volume boundary
    auto [its_b, pdf_b] = sampleBoundaryPoint(scene, sampler->next2D());
    // The interior medium of the sampled shape is the one being traced.
    const Medium *medium = its_b.ptr_med_int;
    assert(medium);

    // NOTE(review): `phase` is not used in the rest of this function.
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);
    // Direction drawn uniformly on the sphere; it is used both as the direction
    // argument of the source subpath and as the first detector ray direction.
    Vector wo = squareToUniformSphere(sampler->next2D());
    // 2. sample the source subpath

    // FIXME : wi
    // Start point pushed slightly along the negated geometric normal.
    Vector dir_in = -its_b.geoFrame.n;
    Vector p_offset = its_b.p + ShadowEpsilon * dir_in;
    // RadianceTracer rt;
    RadianceTracer radianceTracer(m_sampling_mode);
    // auto [lins, radiances] = rt.sampleSource(scene, sampler, medium, p_offset, wo, max_bounces);
    auto [lins, radiances] = radianceTracer.sampleSource(scene, sampler, its_b, wo, max_bounces);
    // 3. sample the detector subpath
    // 4. merge the two subpaths
    // 5. compute the boundary normal velocity
    // 6. compute the boundary term
    // 7. accumulate the boundary term to the gradient image

    // without nee
    // {
    //     ImportanceTracer2 pt{sceneAD, its_b, max_bounces, bRec.d_image, lins, bRec.grad_image};
    //     pt.handleMedium(scene, sampler, medium, p_offset, 1, Spectrum::Ones() * medium->sigS(p_offset) / pdf_b);

    //     ImportanceTracer2 pt2{sceneAD, its_b, max_bounces, bRec.d_image, lins * medium->sigS(p_offset) / pdf_b / INV_FOURPI, bRec.grad_image};
    //     pt2.importance(scene, sampler, medium, Ray{p_offset, wo}, false, 1, max_bounces, Spectrum::Ones());
    // }
    // with nee
    {
        // The tracer's context holds pointers to locals of this function
        // (its_b, radiances) and to the images in bRec.
        ImportanceTracer importanceTracer{ { &sceneAD, &its_b, max_bounces, &bRec.d_image, &bRec.grad_image, &radiances },
                                           m_sampling_mode };
        // connect to camera directly
        if (not (m_sampling_mode & ESkipSensor))
            importanceTracer.handleBoundary(scene, sampler, medium, its_b.p, Spectrum::Ones() * medium->sigS(its_b.p) / pdf_b);
        
        // add boundary interaction into mis_ctx
        importanceTracer.mis_ctx.append(ImportanceTracer::Vertex::createBoundary(its_b, pdf_b, INV_FOURPI));
        
        // trace importance
        // The throughput is sigS / pdf_b divided by INV_FOURPI, the same constant
        // that is passed to createBoundary() above.
        if(not (m_sampling_mode & (EMIS | EDebugMIS))) /* surface MIS doesn't need boundary bidirectional integrator */
        {    
            importanceTracer.importance(scene, sampler, medium, Ray{ p_offset, wo }, false, 1, max_bounces,
                                        Spectrum::Ones() * medium->sigS(its_b.p) / pdf_b / INV_FOURPI);
        }
    }
}

/**
 * Differential render pass of the boundary bidirectional integrator.
 *
 * The image is processed block by block in an OpenMP parallel loop; for every
 * pixel of a block, options.num_samples boundary samples are drawn with
 * sampleBoundary(). Afterwards the gradient manager and the per-thread
 * gradient images are merged and the derivative scene is reconfigured.
 *
 * @param sceneAD  scene value/derivative pair plus its gradient manager.
 * @param options  supplies num_samples, max_bounces and seed.
 * @param _d_image incoming image values; divided by the sample count and by
 *                 the number of pixels of the crop window before use.
 * @return the flattened gradient image.
 */
ArrayXd BoundaryBidirectional::renderD(SceneAD &sceneAD, const RenderOptions &options, const ArrayXd &_d_image) const {
    PSDR_INFO("BoundaryBidirectional renderD with spp = {}, ", options.num_samples);
    const Scene            &scene   = sceneAD.val;
    [[maybe_unused]] Scene &d_scene = sceneAD.der;
    GradientManager<Scene> &gm      = sceneAD.gm;
    gm.setZero(); // zero multi-thread gradient

    const int nworker = omp_get_num_procs();
    const auto &camera = scene.camera;
    Spiral spiral(camera.getCropSize(), camera.getOffset(), 8 /* block_size */);
    // Pre-scale the incoming image so each sample carries 1 / (spp * #pixels).
    ImageBlock d_image = ImageBlock(camera.getOffset(),
                                    camera.getCropSize(),
                                    _d_image / options.num_samples / camera.getCropSize().prod());
    ImageBlock grad_image = ImageBlock(camera.getOffset(), camera.getCropSize());
    // Hands out one gradient image per worker thread; merged into grad_image below.
    ThreadManager thread_manager(grad_image, nworker);
    Timer _("Boundary Bidirectional");

    // NOTE(review): spiral.next_block() is called from inside the parallel loop
    // without a critical section -- confirm that Spiral is safe for concurrent use.
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int i = 0; i < spiral.block_count(); i++)
    {
        auto [offset, size, block_id] = spiral.next_block();
        ImageBlock block(offset, size);
        const int tid = omp_get_thread_num();
        for (Array2i pixelIdx = block.curPixel(); block.hasNext(); pixelIdx = block.nextPixel())
        {
            // One sampler per pixel, built from the render seed and the flat pixel index.
            int pixel_ravel_idx = ravel_multi_index(pixelIdx, camera.getCropSize());
            RndSampler sampler(options.seed, pixel_ravel_idx);
            for (int j = 0; j < options.num_samples; j++)
            {
                VolumeBoundaryQueryRecord bRec{&sampler, options.max_bounces, d_image, thread_manager.get(tid)};
                sampleBoundary(sceneAD, bRec);
            }
        }
        // Progress output is serialised across threads.
        if (verbose)
#pragma omp critical
            progressIndicator(static_cast<Float>(spiral.block_counter()) / spiral.block_count());
    }
    if (verbose)
        std::cout << std::endl;

    // merge d_scenes
    gm.merge();
    // merge the per-thread gradient images
    thread_manager.merge();
    /* normal related */
#ifdef NORMAL_PREPROCESS
    Timer preprocess_timer("preprocess");
    d_precompute_normal(scene, d_scene);
#endif
    d_scene.configureD(scene);
    return grad_image.flattened();
}

/// Primal render pass: this integrator produces no primal image. Neither
/// argument is used and a zero array created with Zero(1, 3) is returned.
ArrayXd BoundaryBidirectional::renderC(const Scene &scene, const RenderOptions &options) const
{
    ArrayXd placeholder = ArrayXd::Zero(1, 3);
    return placeholder;
}