#include "pathADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include <iomanip>

#define ALTERNATIVE_EDGE_SAMPLING

// Primal (non-differentiable) unidirectional path-tracing estimate of the
// radiance arriving along `_ray`.
//
// At every vertex two strategies are combined with the power heuristic
// (squared pdfs):
//   * emitter sampling through scene.sampleEmitterDirect() (next-event estimation);
//   * BSDF sampling, whose emission is only counted when the continuation ray
//     hits an emitter.
// Emission seen directly by `_ray` (depth 0) is added without MIS weighting.
//
// @param scene      scene to intersect and sample emitters from
// @param sampler    random-number source consumed by both strategies
// @param _ray       primary ray; it is copied, the caller's ray is not modified
// @param max_depth  maximum number of scattering events; with 0 only directly
//                   visible emission is returned
// @return           radiance estimate, Spectrum::Zero() when the primary ray
//                   yields no valid intersection
Spectrum PathTracerAD_PathSpace::Li(const Scene& scene, RndSampler* sampler, const Ray& _ray, int max_depth) const {
    Ray ray(_ray);
    Intersection its;
    Spectrum ret = Spectrum::Zero();
    scene.rayIntersect(ray, true, its);
    if (!its.isValid())
        return ret;

    Spectrum throughput = Spectrum::Ones();
    int depth = 0;
    while (depth <= max_depth && its.isValid()) {
        // Emission hit by the camera ray itself; later emitter hits are handled
        // (MIS-weighted) in the BSDF-sampling branch below.
        if (its.isEmitter() && depth == 0)
            ret += throughput*its.Le(-ray.dir);
        if (depth >= max_depth) break;

        // Direct illumination (emitter sampling).
        // NOTE(review): pdf_nee and bsdf_pdf are presumably both per solid
        // angle here (the BSDF-hit branch converts the emitter pdf to solid
        // angle before mixing) -- confirm against Scene::sampleEmitterDirect.
        Float pdf_nee;
        Vector wo;
        auto value = scene.sampleEmitterDirect(its, sampler->next2D(), sampler, wo, pdf_nee);
        if (!value.isZero()) {
            auto bsdf_val = its.evalBSDF(wo);
            Float bsdf_pdf = its.pdfBSDF(wo);
            auto mis_weight = square(pdf_nee)/(square(pdf_nee) + square(bsdf_pdf));
            ret += throughput * value * bsdf_val * mis_weight;
        }

        // Indirect illumination (BSDF sampling).
        // bsdf_eta is a required out-parameter of sampleBSDF; this estimator
        // does not need its value, so no relative IOR is accumulated.
        Float bsdf_pdf, bsdf_eta;
        auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
        if (bsdf_weight.isZero())
            break;

        // sampleBSDF returns a local direction; continue the path in world space.
        wo = its.toWorld(wo);
        ray = Ray(its.p, wo);

        if (!scene.rayIntersect(ray, true, its))
            break;

        throughput *= bsdf_weight;
        if (its.isEmitter()) {
            Spectrum light_contrib = its.Le(-ray.dir);
            if (!light_contrib.isZero()) {
                // Convert the area-measure emitter pdf to solid angle so it can be
                // mixed with bsdf_pdf. The sign of geometry_term is irrelevant:
                // pdf_nee only enters the power heuristic squared.
                auto dist_sq = (its.p - ray.org).squaredNorm();
                auto geometry_term = its.geoFrame.n.dot(-ray.dir)/dist_sq;
                pdf_nee = scene.pdfEmitterSample(its)/geometry_term;
                auto mis_weight = square(bsdf_pdf)/(square(pdf_nee) + square(bsdf_pdf));
                ret += throughput * light_contrib * mis_weight;
            }
        }

        depth++;
    }
    return ret;
}

// Differentiable counterpart of Li(): estimates the radiance arriving along
// `_ray` while carrying derivatives in SpectrumAD / FloatAD values.
//
// Emitter sampling and BSDF sampling are combined with the balance heuristic
// (pdf1/(pdf1+pdf2)), with both pdfs expressed per unit area: the BSDF pdf is
// multiplied by the geometry term G. (Li() uses the power heuristic instead.)
//
// BSDF sampling is done on a non-AD copy of the current vertex. The next vertex
// is found with a non-AD ray cast and then re-attached differentiably via
// scene.getPoint(), whose FloatAD output J multiplies the throughput.
// NOTE(review): J is presumably the Jacobian of the path-space
// reparameterisation -- confirm against Scene::getPoint.
//
// @param scene      scene to intersect and sample emitters from
// @param sampler    random-number source consumed by both strategies
// @param _ray       differentiable primary ray (copied)
// @param max_depth  maximum number of scattering events; 0 returns only directly
//                   visible emission
// @return           radiance estimate with derivatives; default SpectrumAD() if
//                   the primary ray hits nothing
SpectrumAD PathTracerAD_PathSpace::LiAD(const Scene &scene, RndSampler* sampler, const RayAD &_ray, int max_depth) const {
    RayAD ray(_ray);
    IntersectionAD its;
    // The primary hit is computed differentiably.
    if ( !scene.rayIntersectAD(ray, false, its) ) return SpectrumAD();
    assert(its.isValid());

    SpectrumAD ret, throughput(Spectrum::Ones());
    for ( int depth = 0; depth <= max_depth; ++depth ) {
        // Emission seen directly from the camera is added without MIS; later
        // emitter hits are counted (MIS-weighted) in the BSDF-sampling branch.
        if ( its.isEmitter() && depth == 0 )
            ret += throughput * its.Le(-ray.dir);
        if (depth >= max_depth) break;

        // pdf1: emitter-sampling pdf, pdf2: BSDF-sampling pdf (both area measure).
        Float pdf1, pdf2;
        // Light sampling
        {
            VectorAD wo;
            Float G;
            // wo is passed straight to evalBSDF, which the BSDF branch below calls
            // with a local direction -- so wo is presumably in the local frame.
            SpectrumAD value = scene.sampleEmitterDirectAD(its, sampler->next2D(), sampler, wo, pdf1, &G);
            if (!value.isZero(Epsilon)) {
                SpectrumAD bsdf_val = its.evalBSDF(wo);
                // Convert the BSDF pdf to area measure with the G reported by
                // the emitter sampler so it is comparable with pdf1.
                pdf2 = its.pdfBSDF(wo.val)*G;
                ret += throughput*value*bsdf_val*(pdf1/(pdf1 + pdf2));
            }
        }

        // BSDF sampling
        {
            // The direction is sampled on a non-AD copy; derivatives are
            // reintroduced later through getPoint().
            Intersection _its = its.toIntersection();

            Vector wo_local;
            Float bsdf_eta;
            if ( _its.sampleBSDF(sampler->next3D(), wo_local, pdf2, bsdf_eta).isZero(Epsilon) ) break;

            // Terminate when the shading-frame and geometric-normal sides
            // disagree for either the incoming or the outgoing direction.
            Vector wo = _its.toWorld(wo_local), wi = -ray.dir.val;
            if ( wi.dot(_its.geoFrame.n)*_its.wi.z() < Epsilon || wo.dot(_its.geoFrame.n)*wo_local.z() < Epsilon ) break;

            // Find the next vertex without AD, then attach it differentiably to
            // the current vertex position its.p.
            Intersection its1;
            IntersectionAD its1_AD;
            FloatAD J;
            if ( !scene.rayIntersect(Ray(_its.p, wo), true, its1) ) break;
            assert(its1.isValid());
            scene.getPoint(its1, its.p, its1_AD, J);

            // dir is presumably the direction from its1 back toward its.p;
            // -dir is the travel direction of this path segment.
            VectorAD dir = its1_AD.toWorld(its1_AD.wi);     // Pointing from the intersection point
            // Geometry term at the new vertex: |cos| / distance^2.
            FloatAD G = its1_AD.geoFrame.n.dot(dir).abs()/its1_AD.t.square();
            // Solid-angle BSDF pdf -> area measure.
            pdf2 *= G.val;
            // evalBSDF takes a local direction, hence toLocal(-dir).
            // NOTE(review): if G.val is exactly 0 (grazing hit), pdf2 becomes 0 and
            // this divides by zero; the resulting NaN/inf is not caught by the
            // isZero check below.
            throughput *= its.evalBSDF(its.toLocal(-dir))*G*J/pdf2;
            if ( throughput.isZero(Epsilon) ) break;

            if ( its1.isEmitter()) {
                SpectrumAD light_contrib = its1_AD.Le(dir);
                if ( !light_contrib.isZero(Epsilon) ) {
                    // Area-measure pdf of reaching its1 by emitter sampling.
                    pdf1 = scene.pdfEmitterSample(its1);
                    ret += throughput*light_contrib*(pdf2/(pdf1 + pdf2));
                }
            }

            // Advance the path to the new vertex.
            ray.org = its.p;
            ray.dir = -dir;
            its = its1_AD;
        }
    }
    return ret;
}

Spectrum PathTracerAD_PathSpace::pixelColor(const Scene &scene, const RenderOptions &options, RndSampler *sampler, Float x, Float y) const {
    // Primal estimate for camera sample (x, y): build the camera ray, then
    // trace it with Li() for at most options.max_bounces scattering events.
    const Ray camera_ray(scene.camera.samplePrimaryRay(x, y));
    return Li(scene, sampler, camera_ray, options.max_bounces);
}

SpectrumAD PathTracerAD_PathSpace::pixelColorAD(const Scene &scene, const RenderOptions &options, RndSampler *sampler, Float x, Float y) const {
    // Differentiable estimate for camera sample (x, y): build the AD camera
    // ray, then trace it with LiAD() for at most options.max_bounces events.
    const RayAD camera_ray(scene.camera.samplePrimaryRayAD(x, y));
    return LiAD(scene, sampler, camera_ray, options.max_bounces);
}

void PathTracerAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    // First run the base-class IntegratorAD_PathSpace::render. Then run each
    // edge pass whose sample count in `options` is positive. Every pass
    // receives the same rendered_image buffer.
    IntegratorAD_PathSpace::render(scene, options, rendered_image);

#ifdef USE_BOUNDARY_NEE
    // Direct edge pass: only compiled when USE_BOUNDARY_NEE is defined.
    const bool want_direct_edges = options.num_samples_secondary_edge_direct > 0;
    if ( want_direct_edges )
        renderEdgesDirect(scene, options, rendered_image);
#endif

    // Indirect edge pass.
    const bool want_indirect_edges = options.num_samples_secondary_edge_indirect > 0;
    if ( want_indirect_edges )
        renderEdges(scene, options, rendered_image);
}
