#include "boundary_unidir.h"

#include <cmath>

#include <algorithm1.h>
#include <core/statistics.h>
#include <render/intersection.h>
#include <render/scene.h>

// File-local statistics counter, reported as "Average path length" under the
// "Boundary Unidirectional" category with the EAverage kind (presumably
// sum / base count -- confirm in core/statistics.h).
// LiAD's stats callback calls incrementBase() and adds stats.depth each time it
// fires; the callback installed in Lins only adds stats.depth (no new base),
// so in-scattering sub-path depths are folded into the enclosing path's sample.
static StatsCounter avgPathLength("Boundary Unidirectional", "Average path length", EAverage);

//=============================================================================
//              BoundaryUnidirectionalBase for Unidirectional Boundary
//=============================================================================
int BoundaryUnidirectionalBase::handleBoundary(const Scene &scene, RndSampler *sampler,
                                               const Medium *medium, const Ray &_ray,
                                               int max_bounces, const Spectrum &throughput,
                                               Float pdf) {
    Ray ray = _ray;
    // find all null intersections
    std::vector<Intersection> its_list = scene.rayIntersectAll(ray, false);
    if (its_list.empty())
        return 0;
    // estimate the boundary integral
    for (size_t i = 0; i < its_list.size(); i++) {
        const Intersection &its   = its_list[i];
        Spectrum            _Lins = Lins(scene, sampler, its, -ray.dir, max_bounces - 1);
        // eq. 31
        Float  transmittance = scene.evalTransmittance(ray.org, true, its.p, true, medium, sampler);
        Vector dir           = (its.p - ray.org).normalized();
        Float  d_u           = (throughput * ctx.dI * _Lins * its.ptr_med_int->sigS(its.p) * transmittance).sum() /
                    std::abs(ray.dir.dot(its.geoFrame.n));
        // MIS weight
        if (m_sampling_mode & EMISPath) {
            Float mis_weight = 1.;
            Vertex &prev = mis_ctx.vertices.back();
            Vertex curr       = Vertex::createMedium(its.p, its.ptr_med_int,
                                                     /* pdf_fwd */ prev.pdf_next * geometric(prev.p, its.p, its.geoFrame.n),
                                                     0. /* unused */);
            prev.pdf_rev = curr.convertDensity(INV_FOURPI, prev) * 
                        scene.evalTransmittance(its.p, true, prev.getP(), true, its.getTargetMedium(prev.getP() - its.p), sampler);
            curr.pdf_rev = scene.pdfMediumBoundaryPoint(its);
            mis_ctx.append(curr);
            Float pdfA       = mis_ctx.pdfFwd();
            Float pdfB       = mis_ctx.pdfRev();
            mis_weight = pdfA / (pdfA + scene.camera.getNumPixels() * pdfB);
            mis_ctx.vertices.pop_back();
            if(!isfinite(mis_weight))
            {
                PSDR_INFO("mis weight is not finite");
                return its_list.size();
            }
            d_u *= mis_weight;
        } else if (m_sampling_mode & (EMIS | EDebugMIS)) {
            Float mis_weight = 1.;
            Float pdfA       = pdf * geometric(ray.org, its.p, its.geoFrame.n);
            Float pdfB       = scene.pdfMediumBoundaryPoint(its);
            if (m_sampling_mode & (EMISBalance | EDebugMISBalance))
                mis_weight = misWeightBalance(pdfA, pdfB);
            else if (m_sampling_mode & (EMISPower | EDebugMISPower))
                mis_weight = misWeightPower(pdfA, pdfB);
            else
                Throw("Unknown sampling mode {}", m_sampling_mode);
            d_u *= mis_weight;
        }
        algorithm1::d_velocity(*ctx.sceneAD, its, d_u);
    }
    return its_list.size();
}

// Estimates the in-scattered radiance at the medium-boundary point `its.p`,
// using the phase function of `its.ptr_med_int` with direction `wi`.
// Returns zero when max_bounces < 0. The estimate is the sum of three terms:
//   1. emitter sampling (NEE), power-heuristic weighted against phase sampling;
//   2. one phase-sampled direction that looks for a directly hit emitter,
//      power-heuristic weighted against emitter sampling;
//   3. a VolpathBase::Li estimate along the same phase-sampled direction with
//      inc_emission = false.
// NOTE: the order of sampler draws below (next2D for NEE, next2D for the phase
// sample, then vb.Li) is part of the estimator's behaviour.
Spectrum BoundaryUnidirectionalBase::Lins(const Scene &scene, RndSampler *sampler,
                                          const Intersection &its,
                                          const Vector &wi, int max_bounces) {
    Spectrum ret(0.);
    if (max_bounces < 0)
        return ret;
    // nee emitter
    DirectSamplingRecord dRec(its.p);
    Spectrum             value = scene.sampleAttenuatedEmitterDirect(
        dRec, its, sampler->next2D(), sampler, nullptr /* unused */);

    const Medium        *med_int  = its.ptr_med_int;
    const PhaseFunction *phase    = scene.getPhase(med_int->phase_id);
    Float                phaseVal = phase->eval(wi, dRec.dir);
    // dRec.G converts the phase (solid-angle) pdf into dRec.pdf's measure.
    Float mis_weight = misWeightPower(dRec.pdf, phase->pdf(wi, dRec.dir) * dRec.G);
    ret += value * phaseVal * mis_weight;

    // sample a direction from phase function
    // NOTE(review): phase_val presumably is the sampling weight (eval / pdf) --
    // it scales both the emitter hit and the vb.Li term below; confirm.
    Vector      wo;
    Float       phase_val = phase->sample(wi, sampler->next2D(), wo);
    VolpathBase vb;

    // Medium the phase-sampled ray starts in; stays nullptr unless `its` is a
    // valid medium transition.
    const Medium *medium = nullptr;
    if (its.isValid() && its.isMediumTransition())
            medium = its.getTargetMedium(wo);
    Intersection _its;
    // dRec is reused: the MIS weight below reads dRec.dir / G / pdf after this
    // call, so this relies on rayIntersectAndLookForEmitter refreshing dRec for
    // the emitter that was hit -- TODO confirm.
    value = scene.rayIntersectAndLookForEmitter(Ray(its.p, wo), true, sampler, medium, _its, dRec);
    mis_weight = misWeightPower(phase->pdf(wi, dRec.dir) * dRec.G, dRec.pdf);
    ret += value * phase_val * mis_weight;

    // Sub-path depths are added to the current average-path-length sample
    // (no incrementBase here).
    vb.stats_callback = [](const VolpathBase::Stats &stats) {
        avgPathLength += stats.depth;
    };

    // Start the multiple-scattering estimate slightly inside the boundary
    // (opposite to the geometric normal).
    Vector dir_in    = -its.geoFrame.n;
    Vector p_shifted = its.p + dir_in * 1e-2; // if 1e-3, will cause a lot of dead loops, same for Ray{p, wo}.shift()
    // inc_emission = false: presumably because directly hit emission is already
    // covered by the two MIS-weighted terms above -- confirm against VolpathBase::Li.
    // NOTE(review): unlike the emitter search above, getTargetMedium(wo) is
    // called here without the isMediumTransition() guard.
    ret += phase_val * vb.Li(scene, sampler, its.getTargetMedium(wo),
                             Ray{ p_shifted, wo }, max_bounces, false /* inc_emission */);
    return ret;
}

void BoundaryUnidirectionalBase::handleSensor(const Scene &scene, RndSampler *sampler, const Ray &_ray,
                                              const Array2i &pixel_idx, const Medium *medium, int max_bounces) {
    Ray ray = _ray;
    // find all null intersections
    std::vector<Intersection> its_list = scene.rayIntersectAll(ray, false);
    if (its_list.empty())
        return;
    // estimate the boundary integral
    for (size_t i = 0; i < its_list.size(); i++) {
        const Intersection &its   = its_list[i];
        Spectrum            _Lins = Lins(scene, sampler, its, -ray.dir, max_bounces - 1);
        // eq. 31
        Float  transmittance = scene.evalTransmittance(ray.org, true, its.p, true, medium, sampler);
        Vector dir           = (its.p - ray.org).normalized();

        Float d_u = (ctx.dI * _Lins * its.ptr_med_int->sigS(its.p) * transmittance).sum();

        if (m_sampling_mode & EMISPath) {
            Float                      mis_weight = 1.;
            CameraDirectSamplingRecord cRec;
            if (!scene.camera.sampleDirect(its.p, cRec))
                PSDR_ERROR("sample direct failed");
            if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
                PSDR_ERROR("not visible");
            Float pdfA1 = scene.camera.eval(pixel_idx.x(), pixel_idx.y(), its.p);
            Float pdfA = pdfA1 * std::abs(ray.dir.dot(its.geoFrame.n));
            Float pdfB  = scene.camera.getNumPixels() / scene.getMediumArea() *
                         scene.camera.pdfPixel(pixel_idx.x(), pixel_idx.y(), -cRec.dir); // REVIEW / 2
            mis_weight = pdfA1 * pdfA / (pdfA * pdfA + pdfB * pdfB); // power heuristic
            d_u *= mis_weight;
        } else if (m_sampling_mode & (EMIS | EDebugMIS)) {
            Float                      mis_weight = 1.;
            CameraDirectSamplingRecord cRec;
            if (!scene.camera.sampleDirect(its.p, cRec))
                PSDR_ERROR("sample direct failed");
            if (!scene.isVisible(its.p, true, scene.camera.cpos, true))
                PSDR_ERROR("not visible");

            Float pdfA1 = scene.camera.eval(pixel_idx.x(), pixel_idx.y(), its.p), // area measure
                pdfA    = pdfA1 * std::abs(ray.dir.dot(its.geoFrame.n)),
                  pdfB  = scene.camera.getNumPixels() / scene.getArea();

            if (m_sampling_mode & (EMISBalance | EDebugMISBalance)) {
                mis_weight = pdfA1 / (pdfA + pdfB);
            } else if (m_sampling_mode & (EMISPower | EDebugMISPower)) {
                mis_weight = pdfA * pdfA1 / (pdfA * pdfA + pdfB * pdfB);
            } else {
                Throw("unknown mis mode {}", m_sampling_mode);
            }

            d_u *= mis_weight;
        } else {
            d_u /= std::abs(ray.dir.dot(its.geoFrame.n));
        }
        algorithm1::d_velocity(*ctx.sceneAD, its, d_u);
    }
}

// Boundary contribution for a path vertex `p` inside participating medium
// `medium`, with incoming direction `wi`, outgoing direction `wo` and the path
// throughput so far. pdf_phase = phase->pdf(wi, wo) is used as the sampling pdf
// of `wo` (presumably the caller phase-sampled it -- confirm).
// Does nothing when max_bounces <= 0. Depending on m_sampling_mode:
//   * EMISPath: appends a medium vertex for `p` to mis_ctx (and updates the
//     previous vertex's reverse pdf), then estimates the boundary term along
//     `wo` with handleBoundary. The solid-angle/area flags are not consulted.
//   * otherwise, ESolidAngleSampling: boundary term along `wo` via handleBoundary;
//     and/or EAreaSampling: samples one medium-boundary point by area and adds
//     its (optionally MIS-weighted) contribution to the velocity derivative.
void BoundaryUnidirectionalBase::handleMedium(const Scene &scene, RndSampler *sampler,
                                              const Medium *medium, const Vector &p, 
                                              const Vector &wi, const Vector &wo,
                                              int max_bounces, const Spectrum &throughput) {
    const PhaseFunction *phase = scene.getPhase(medium->phase_id);
    assert(phase);
    if (max_bounces <= 0)
        return;
    Float  pdf_phase   = phase->pdf(wi, wo);
    Ray    ray{ p, wo };
    if (m_sampling_mode & EMISPath) {
        // pdf_fwd of the new vertex: previous vertex's sampling pdf times
        // geometric(prev, p) times the (stochastically estimated) transmittance.
        Vertex &prev = mis_ctx.vertices.back();
        Vertex  curr = Vertex::createMedium(p, medium,
                                            /* pdf_fwd */ prev.pdf_next * geometric(prev.getP(), p) *
                                                scene.evalTransmittance(p, false, prev.getP(), true, medium, sampler),
                                            pdf_phase); /* pdf_next */
        // The camera vertex keeps the pdf_rev it was created with.
        if (prev.type != Vertex::EType::ECamera)
            prev.pdf_rev = curr.convertDensity(phase->pdf(wo, wi), prev) *
                           scene.evalTransmittance(p, false, prev.getP(), true, medium, sampler);
        // The vertex stays in the context: handleBoundary reads it back as
        // mis_ctx.vertices.back() for its own MIS bookkeeping.
        mis_ctx.append(curr);
        handleBoundary(scene, sampler, medium, ray, max_bounces, throughput, pdf_phase);
    } else {
        if (m_sampling_mode & ESolidAngleSampling)
            handleBoundary(scene, sampler, medium, ray, max_bounces, throughput, pdf_phase);
        if (m_sampling_mode & EAreaSampling) {
            // Area-sample a point `its` on the medium boundary and connect p to it.
            PositionSamplingRecord pRec;
            auto                   rnd           = sampler->next2D();
            Intersection           its           = scene.sampleMediumBoundary(rnd, pRec);
            Vector                 dir           = (its.p - p).normalized();
            Float                  phase_value   = phase->eval(wi, dir);
            Spectrum               _Lins         = Lins(scene, sampler, its, -dir, max_bounces - 1);
            Float                  transmittance = scene.evalTransmittance(p, true, its.p, true, medium, sampler);
            Float                  d_u           = (throughput * ctx.dI * _Lins * its.ptr_med_int->sigS(its.p) * phase_value * transmittance * geometric(p, its.p)).sum() / pRec.pdf;
            // pdfA: area sampling (this technique); pdfB: phase sampling converted
            // to the area measure at the boundary normal.
            Float                  pdfA          = scene.pdfMediumBoundaryPoint(its);
            Float                  pdfB          = phase->pdf(wi, dir) * geometric(p, its.p, its.geoFrame.n);
            if (m_sampling_mode & (EMIS | EDebugMIS)) {
                Float mis_weight = 1.;
                if (m_sampling_mode & (EMISBalance | EDebugMISBalance))
                    mis_weight = misWeightBalance(pdfA, pdfB);
                else if (m_sampling_mode & (EMISPower | EDebugMISPower))
                    mis_weight = misWeightPower(pdfA, pdfB);
                else
                    Throw("unknown mis mode");
                d_u *= mis_weight;
            }
            algorithm1::d_velocity(*ctx.sceneAD, its, d_u);
        }
    }
}

void BoundaryUnidirectionalBase::handleSurface(const Scene &scene, RndSampler *sampler,
                                               const Intersection &its,
                                               int max_bounces, const Spectrum &throughput) {
    if (max_bounces <= 0)
        return;
    Vector   wo;
    Float    bsdf_pdf, bsdf_eta;
    Spectrum bsdf_val = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
    wo                = its.toWorld(wo);
    Ray           ray{ its.p, wo };
    const Medium *medium = its.getTargetMedium(wo);
    if (m_sampling_mode & ESolidAngleSampling)
        handleBoundary(scene, sampler, medium, ray.shifted(), max_bounces, throughput * bsdf_val, bsdf_pdf);
    if (m_sampling_mode & EAreaSampling) {
        PositionSamplingRecord pRec;
        Intersection           its_b         = scene.sampleMediumBoundary(sampler->next2D(), pRec);
        Vector                 dir           = (its_b.p - its.p).normalized();
        Spectrum               bsdf_value    = its.evalBSDF(its.toLocal(dir));
        Spectrum               _Lins         = Lins(scene, sampler, its, -dir, max_bounces - 1);
        Float                  transmittance = scene.evalTransmittance(its.p, true, its_b.p, true, medium, sampler);
        Float                  d_u           = (throughput * ctx.dI * _Lins * its.ptr_med_int->sigS(its.p) * bsdf_value * transmittance * geometric(its.p, its_b.p)).sum() / pRec.pdf;
        Float                  pdfA          = scene.pdfMediumBoundaryPoint(its);
        Float                  pdfB          = its.pdfBSDF(its.toLocal(dir)) * geometric(its.p, its_b.p, its.geoFrame.n);
        if (m_sampling_mode & (EMIS | EDebugMIS)) {
            Float mis_weight = 1.;
            if (m_sampling_mode & (EMISBalance | EDebugMISBalance))
                mis_weight = misWeightBalance(pdfA, pdfB);
            else if (m_sampling_mode & (EMISPower | EDebugMISPower))
                mis_weight = misWeightPower(pdfA, pdfB);
            else
                Throw("unknown mis mode");
            d_u *= mis_weight;
        }
        algorithm1::d_velocity(*ctx.sceneAD, its, d_u);
    }
}

// Returns false when EMISFirst is set in m_sampling_mode and true otherwise.
// The bounce index is not consulted. Per the original note, EMISFirst restricts
// the estimator to light paths whose first vertex is the boundary vertex.
bool BoundaryUnidirectionalBase::isNext(int /* bounces */) const {
    const bool mis_first_only = static_cast<bool>(m_sampling_mode & EMISFirst);
    return !mis_first_only;
}

// ================================================================================
//                              BoundaryUnidirectional
// ================================================================================
// Primal radiance is not provided by this integrator; its work is done in LiAD.
// Calling Li trips the assertion in debug builds. With NDEBUG it returns zero
// radiance. None of the arguments are used.
Spectrum BoundaryUnidirectional::Li(const Scene & /* scene */, const Ray & /* ray */,
                                    RadianceQueryRecord & /* rRec */) const {
    assert(false);
    return Spectrum::Zero();
}

// Runs the unidirectional boundary estimator for one camera ray.
// A per-call BoundaryUnidirectionalBase is bound to the scene's AD data and the
// adjoint `d_res`. Its MIS path context is seeded with the camera vertex of
// `rRec.pixel_idx`, and the estimator is then traced from the camera (no initial
// medium) with the record's sampler, bounce limit and emission flag.
void BoundaryUnidirectional::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const {
    BoundaryUnidirectionalBase estimator({ &sceneAD, d_res }, m_sampling_mode);

    // One average-path-length sample per callback, weighted by the path depth.
    estimator.stats_callback = [](const VolpathBase::Stats &stats) {
        avgPathLength.incrementBase();
        avgPathLength += stats.depth;
    };

    auto      &camera   = sceneAD.val.camera;
    const auto px       = rRec.pixel_idx.x();
    const auto py       = rRec.pixel_idx.y();
    const auto pdf_next = camera.evalDir(px, py, ray.dir);
    const auto pdf_rev  = camera.pdfPixel(px, py, ray.dir);
    estimator.mis_ctx.append(Vertex::createCamera(camera, rRec.pixel_idx, pdf_next, pdf_rev));

    estimator.Li1(sceneAD.val, rRec.sampler, nullptr, ray, rRec.pixel_idx, rRec.max_bounces, rRec.incEmission);
}