#include "path2anti.h"
#include <render/common.h>
#include <render/imageblock.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <render/scene.h>
#include <core/timer.h>
#include <iomanip>
#include "algorithm1.h"
#include <signal.h>

namespace path2anti
{
    Spectrum __Li(const Scene &scene, const Ray &_ray, RadianceQueryRecord &rRec, LightPath *path, LightPath *pathDual, bool &__hasAntithetic)
    {
        Ray ray(_ray), rayDual(_ray);
        RndSampler *sampler = rRec.sampler;
        Array2i pixel_idx = rRec.pixel_idx;
        Intersection its, itsDual;
        if (path)
        {
            path->clear(pixel_idx);
            path->append(scene.camera); // NOTE
        }
        if (pathDual)
        {
            pathDual->clear(pixel_idx);
            pathDual->append(scene.camera); // NOTE
        }

        const int max_depth = rRec.max_bounces;
        bool hasPrimary = true, hasAntithetic = false;
        Spectrum ret = Spectrum::Zero();
        scene.rayIntersect(ray, true, its);
        if (!its.isValid()){
            __hasAntithetic = hasAntithetic;
            return Spectrum::Zero();
        }

        its.pdf = 1.;
        if (path) path->append(its);
        if (pathDual) pathDual->append(its);
        if (its.ptr_bsdf->isConductor()){
            hasAntithetic = true;
        }
        else {
            hasAntithetic = false;
        }

        Spectrum throughput = Spectrum::Ones();
        Spectrum throughputDual = Spectrum::Ones();
        Float weightPrim = 1.0, weightDual = 0.0;
        Vector wo, woDual;

        // Float eta = 1.0f;

        // directly seen emitter
        if (its.isEmitter())
            ret += throughput * its.Le(-ray.dir);

        //first bounce nee
        DirectSamplingRecord dRec(its);
        auto value = scene.sampleEmitterDirect(sampler->next2D(), dRec);
        wo = its.toLocal(dRec.dir);
        if (!value.isZero(Epsilon))
        {
            auto bsdf_val = its.evalBSDF(wo);
            Float bsdf_pdf = its.pdfBSDF(wo);
            Float pdf_nee = dRec.pdf / geometric(its.p, dRec.p, dRec.n);
            auto mis_weight = square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
            ret += throughput * value * bsdf_val * mis_weight;

            dRec.pdf /= mis_weight;  // NOTE
            if (path) path->append_nee({dRec}); // NOTE
        }

        // sample first bounce bsdf
        Float bsdf_pdf, bsdf_pdf_dual, bsdf_eta, bsdf_eta_dual;
        Spectrum bsdf_weight, bsdf_weight_dual ;
        Vector pre_p;
        Array3d rnd = sampler->next3D();
        if (!hasAntithetic){
            bsdf_weight= its.sampleBSDF(rnd, wo, bsdf_pdf, bsdf_eta);
            if (bsdf_weight.isZero(Epsilon))
                hasPrimary = false;
            wo = its.toWorld(wo);
            ray = Ray(its.p, wo);

            pre_p = its.p;
            if (!scene.rayIntersect(ray, true, its))
                hasPrimary = false;

            throughput *= bsdf_weight;
        } else {
            its.sampleBSDFDual(rnd, wo, woDual, bsdf_pdf, bsdf_pdf_dual, bsdf_eta, bsdf_eta_dual, bsdf_weight, bsdf_weight_dual);
            if (bsdf_weight.isZero(Epsilon))
                hasPrimary = false;
            if (bsdf_weight_dual.isZero(Epsilon))
                hasAntithetic = false;
            wo = its.toWorld(wo);
            woDual = its.toWorld(woDual);
            ray = Ray(its.p, wo);
            rayDual = Ray(its.p, woDual);

            pre_p = its.p;
            if (!scene.rayIntersect(ray, true, its))
                hasPrimary = false;
            if (!scene.rayIntersect(rayDual, true, itsDual))
                hasAntithetic = false;

            throughput *= bsdf_weight;
            throughputDual *= bsdf_weight_dual;
            if (hasPrimary && hasAntithetic){
                weightPrim = bsdf_pdf / (bsdf_pdf + bsdf_pdf_dual);
                weightDual = bsdf_pdf_dual / (bsdf_pdf + bsdf_pdf_dual);
                if(!(bsdf_pdf > 1e-10 && bsdf_pdf_dual > 1e-10)) {
                    std::cout << "bsdf_pdf: " << bsdf_pdf << std::endl;
                    std::cout << "bsdf_pdf_dual: " << bsdf_pdf_dual << std::endl;
                    std::cout << "bsdf_weight: " << bsdf_weight << std::endl;
                    std::cout << "bsdf_weight_dual: " << bsdf_weight_dual << std::endl;
                    assert(false);
                }
            } else if (hasPrimary){
                weightPrim = 1.0;
                weightDual = 0.0;
            } else if (hasAntithetic){
                weightPrim = 0.0;
                weightDual = 1.0;
            } else {
                __hasAntithetic = hasAntithetic;
                return ret;
            }
        }
        // eta *= bsdf_eta;

        // eval direct lighting
        if (hasPrimary){
            if (its.isEmitter())
            {
                Spectrum light_contrib = its.Le(-ray.dir);
                if (!light_contrib.isZero(Epsilon))
                {
                    auto dist_sq = (its.p - ray.org).squaredNorm();
                    auto geometry_term = its.wi.z() / dist_sq;
                    Float pdf_nee = scene.pdfEmitterSample(its) / geometry_term;
                    auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf)) * weightPrim;
                    ret += throughput * light_contrib * mis_weight;

                    its.pdf = bsdf_pdf * geometric(pre_p, its.p, its.geoFrame.n) / mis_weight;
                    if (path) path->append_bsdf(its); // NOTE
                }
            }
            its.pdf = bsdf_pdf * geometric(pre_p, its.p, its.geoFrame.n);
            if (path) path->append(its); // NOTE
        }
        if (hasAntithetic) {
            if (itsDual.isEmitter())
            {
                Spectrum light_contrib = itsDual.Le(-rayDual.dir);
                if (!light_contrib.isZero(Epsilon))
                {
                    auto dist_sq = (itsDual.p - rayDual.org).squaredNorm();
                    auto geometry_term = itsDual.wi.z() / dist_sq;
                    Float pdf_nee = scene.pdfEmitterSample(itsDual) / geometry_term;
                    auto mis_weight = square(bsdf_pdf_dual) / (square(pdf_nee) + square(bsdf_pdf_dual)) * weightDual;
                    ret += throughputDual * light_contrib * mis_weight;

                    itsDual.pdf = bsdf_pdf_dual * geometric(pre_p, itsDual.p, itsDual.geoFrame.n) / mis_weight;
                    if (pathDual) pathDual->append_bsdf(itsDual); // NOTE
                }
            }
            itsDual.pdf = bsdf_pdf_dual * geometric(pre_p, itsDual.p, itsDual.geoFrame.n);
            if (pathDual) pathDual->append(itsDual); // NOTE
        }

        
        sampler->save();
        if (hasPrimary){
            Intersection _its = its;
            Ray _ray = ray;
            LightPath* _path = path;
            Spectrum _throughput = throughput;
            Float _weight = weightPrim;
            for (int depth = 1; depth < max_depth && _its.isValid(); depth++)
            {
                // Direct illumination
                Float _pdf_nee;
                Vector _wo;
                DirectSamplingRecord _dRec(_its);
                auto _value = scene.sampleEmitterDirect(sampler->next2D(), _dRec);
                _wo = _its.toLocal(_dRec.dir);
                if (!_value.isZero(Epsilon))
                {
                    auto _bsdf_val = _its.evalBSDF(_wo);
                    Float _bsdf_pdf = _its.pdfBSDF(_wo);
                    _pdf_nee = _dRec.pdf / geometric(_its.p, _dRec.p, _dRec.n);
                    auto _mis_weight = _weight;
                    // auto _mis_weight = 1.0;
                    ret += _throughput * _value * _bsdf_val * _mis_weight;

                    _dRec.pdf /= _mis_weight;  // NOTE
                    if (_path) _path->append_nee({_dRec}); // NOTE
                }
                // Indirect illumination
                Float _bsdf_pdf, _bsdf_eta;
                auto _bsdf_weight = _its.sampleBSDF(sampler->next3D(), _wo, _bsdf_pdf, _bsdf_eta);
                if (_bsdf_weight.isZero(Epsilon))
                    break;
                _wo = _its.toWorld(_wo);
                _ray = Ray(_its.p, _wo);

                Vector _pre_p = _its.p;
                if (!scene.rayIntersect(_ray, true, _its))
                    break;

                _throughput *= _bsdf_weight;
                // eta *= bsdf_eta;

                if (_its.isEmitter())
                {
                    Spectrum _light_contrib = _its.Le(-_ray.dir);
                    if (!_light_contrib.isZero(Epsilon))
                    {
                        auto _dist_sq = (_its.p - _ray.org).squaredNorm();
                        auto _geometry_term = _its.wi.z() / _dist_sq;
                        _pdf_nee = scene.pdfEmitterSample(_its) / _geometry_term;
                        auto _mis_weight = square(_bsdf_pdf) / (square(_pdf_nee) + square(_bsdf_pdf)) * _weight;
                        // ret += _throughput * _light_contrib * _mis_weight;

                        _its.pdf = _bsdf_pdf * geometric(pre_p, _its.p, _its.geoFrame.n) / _mis_weight;
                        // if (_path) _path->append_bsdf(_its); // NOTE
                    }
                }
                _its.pdf = _bsdf_pdf * geometric(_pre_p, _its.p, _its.geoFrame.n);
                if (_path) _path->append(_its); // NOTE
            }
        }

        sampler->restore();
        if (hasAntithetic){
            Intersection _its = itsDual;
            Ray _ray = rayDual;
            LightPath* _path = pathDual;
            Spectrum _throughput = throughputDual;
            Float _weight = weightDual;
            for (int depth = 1; depth < max_depth && _its.isValid(); depth++)
            {
                // Direct illumination
                Float _pdf_nee;
                Vector _wo;
                DirectSamplingRecord _dRec(_its);
                auto _value = scene.sampleEmitterDirect(sampler->next2D(), _dRec);
                _wo = _its.toLocal(_dRec.dir);
                if (!_value.isZero(Epsilon))
                {
                    auto _bsdf_val = _its.evalBSDF(_wo);
                    Float _bsdf_pdf = _its.pdfBSDF(_wo);
                    _pdf_nee = _dRec.pdf / geometric(_its.p, _dRec.p, _dRec.n);
                    auto _mis_weight = _weight;
                    // auto _mis_weight = 1.0;
                    ret += _throughput * _value * _bsdf_val * _mis_weight;

                    _dRec.pdf /= _mis_weight;  // NOTE
                    if (_path) _path->append_nee({_dRec}); // NOTE
                }
                // Indirect illumination
                Float _bsdf_pdf, _bsdf_eta;
                auto _bsdf_weight = _its.sampleBSDF(sampler->next3D(), _wo, _bsdf_pdf, _bsdf_eta);
                if (_bsdf_weight.isZero(Epsilon))
                    break;
                _wo = _its.toWorld(_wo);
                _ray = Ray(_its.p, _wo);

                Vector _pre_p = _its.p;
                if (!scene.rayIntersect(_ray, true, _its))
                    break;

                _throughput *= _bsdf_weight;
                // eta *= bsdf_eta;

                if (_its.isEmitter())
                {
                    Spectrum _light_contrib = _its.Le(-_ray.dir);
                    if (!_light_contrib.isZero(Epsilon))
                    {
                        auto _dist_sq = (_its.p - _ray.org).squaredNorm();
                        auto _geometry_term = _its.wi.z() / _dist_sq;
                        _pdf_nee = scene.pdfEmitterSample(_its) / _geometry_term;
                        auto _mis_weight = square(_bsdf_pdf) / (square(_pdf_nee) + square(_bsdf_pdf)) * _weight;
                        // ret += _throughput * _light_contrib * _mis_weight;

                        _its.pdf = _bsdf_pdf * geometric(pre_p, _its.p, _its.geoFrame.n) / _mis_weight;
                        // if (_path) _path->append_bsdf(_its); // NOTE
                    }
                }
                _its.pdf = _bsdf_pdf * geometric(_pre_p, _its.p, _its.geoFrame.n);
                if (_path) _path->append(_its); // NOTE
            }
        }
        __hasAntithetic = hasAntithetic;
        return ret;
    }
} // namespace path2_meta

// Forward (non-differentiated) radiance estimate for `ray`.
// Delegates to path2anti::__Li without path recording (both recorders are null);
// the antithetic flag reported by __Li is not needed here and is discarded.
Spectrum Path2Anti::Li(const Scene &scene, const Ray &ray, RadianceQueryRecord &rRec) const
{
    bool hasAntithetic = false;
    return path2anti::__Li(scene, ray, rRec, nullptr, nullptr, hasAntithetic);
}

// Differentiated radiance estimate for `ray`.
// Traces the primary and antithetic paths with path2anti::__Li, then calls
// algorithm1::d_eval on the primary path with `d_res` and, when __Li reports that
// the antithetic branch was traced, on the antithetic path as well.
// Nothing is differentiated when the forward value is zero (within Epsilon).
void Path2Anti::LiAD(SceneAD &sceneAD, const Ray &ray, RadianceQueryRecord &rRec, const Spectrum &d_res) const
{
    LightPath primaryPath, dualPath;
    bool dualTraced = false;
    const Spectrum radiance = path2anti::__Li(sceneAD.val, ray, rRec, &primaryPath, &dualPath, dualTraced);

    LightPathAD primaryAD(primaryPath), dualAD(dualPath);
    if (radiance.isZero(Epsilon))
        return;

    algorithm1::d_eval(sceneAD.val, sceneAD.getDer(), primaryAD, d_res, rRec.sampler);
    if (!dualTraced)
        return;

    algorithm1::d_eval(sceneAD.val, sceneAD.getDer(), dualAD, d_res, rRec.sampler);
}
