#include "directADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include <chrono>
#include <omp.h>

// Number of (emitter sample, BSDF sample) pairs drawn per primary hit by Li and
// LiAD; the accumulated direct-lighting estimate is divided by this count.
// A typed constant instead of a macro so it obeys scoping and type rules.
constexpr int DIRECT_SAMPLES = 4;


// Non-differentiable direct-illumination estimate of the radiance along `ray`.
//
// Emission at the first hit is added once. Direct lighting at that hit is then
// estimated DIRECT_SAMPLES times and averaged. Each iteration draws one emitter
// sample and one BSDF sample and weights them with the power heuristic
// (squared pdfs).
//
// The BSDF-sampled ray is traced into its own Ray/Intersection (ray1/its1).
// Reusing `its` here would make every iteration after the first sample lights
// and the BSDF from the secondary hit instead of the primary one.
Spectrum DirectIntegratorAD_PathSpace::Li(const Scene& scene, RndSampler* sampler, Ray ray) const {
    Intersection its;
    Spectrum ret(0.0f);

    // Perform the first intersection
    scene.rayIntersect(ray, true, its);
    if (its.isValid()) {
        if (its.isEmitter()) ret += its.Le(-ray.dir);

        Spectrum tmp(0.0f);
        for ( int i = 0; i < DIRECT_SAMPLES; ++i ) {
            // Emitter sampling
            Float pdf_nee;
            Vector wo;
            auto value = scene.sampleEmitterDirect(its, sampler->next2D(), sampler, wo, pdf_nee);
            if (!value.isZero()) {
                auto bsdf_val = its.evalBSDF(wo);
                Float bsdf_pdf = its.pdfBSDF(wo);
                auto mis_weight = square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
                tmp += value * bsdf_val * mis_weight;
            }

            // BSDF sampling
            Float bsdf_pdf, bsdf_eta;
            auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
            if (!bsdf_weight.isZero()) {
                // sampleBSDF yields a local-frame direction; convert it before tracing.
                Ray ray1 = Ray(its.p, its.toWorld(wo));
                // Separate record so `its` stays the primary hit for later iterations.
                Intersection its1;
                if (scene.rayIntersect(ray1, true, its1) && its1.isEmitter()) {
                    Spectrum light_contrib = its1.Le(-ray1.dir);
                    if (!light_contrib.isZero()) {
                        // Presumably converts the emitter's area-measure pdf into
                        // solid angle (pdf_A * r^2 / cos) so it is comparable to
                        // bsdf_pdf -- TODO confirm against pdfEmitterSample.
                        auto dist_sq = (its1.p - ray1.org).squaredNorm();
                        auto geometry_term = its1.wi.z() / dist_sq;
                        pdf_nee = scene.pdfEmitterSample(its1) / geometry_term;
                        auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
                        tmp += bsdf_weight * light_contrib * mis_weight;
                    }
                }
            }
        }
        ret += tmp/static_cast<Float>(DIRECT_SAMPLES);
    }
    return ret;
}


// Differentiable (AD) direct-illumination estimate of the radiance along `_ray`.
//
// Emission at the first hit is added once. Direct lighting at that hit is then
// estimated DIRECT_SAMPLES times and averaged. Each iteration combines one
// emitter sample and one BSDF sample using the balance heuristic pdf1/(pdf1+pdf2).
// NOTE(review): Li in this file uses the power heuristic instead. Both are
// unbiased, but the primal and AD estimators therefore weight samples differently.
SpectrumAD DirectIntegratorAD_PathSpace::LiAD(const Scene& scene, RndSampler* sampler, const RayAD& _ray) const {
    RayAD ray(_ray);

    IntersectionAD its;
    SpectrumAD ret;

    // Perform the first intersection
    scene.rayIntersectAD(ray, true, its);
    if ( its.isValid() ) {
        if ( its.isEmitter() ) ret += its.Le(-ray.dir);

        SpectrumAD tmp;
        for ( int i = 0; i < DIRECT_SAMPLES; ++i ) {
            // Area sampling
            {
                SpectrumAD value;
                // pdf1: emitter-sampling pdf, pdf2: BSDF-sampling pdf scaled by a
                // geometry factor; both are assigned before each use below.
                Float pdf1, pdf2;

                // Light sampling
                {
                    VectorAD wo;
                    Float G;
                    value = scene.sampleEmitterDirectAD(its, sampler->next2D(), sampler, wo, pdf1, &G);
                    if ( !value.isZero(Epsilon) ) {
                        SpectrumAD bsdf_val = its.evalBSDF(wo);
                        // Multiplying by G presumably moves the BSDF pdf into the same
                        // (area) measure as pdf1 -- TODO confirm G's definition.
                        pdf2 = its.pdfBSDF(wo.val)*G;
                        tmp += value*bsdf_val*pdf1/(pdf1 + pdf2);
                    }
                }

                // BSDF sampling
                {
                    Vector wo;
                    Float bsdf_eta;
                    // The direction is sampled and traced with non-AD types; the hit
                    // point is re-expressed with AD values via scene.getPoint below.
                    Intersection _its = its.toIntersection();
                    if ( !_its.sampleBSDF(sampler->next3D(), wo, pdf2, bsdf_eta).isZero(Epsilon) ) {
                        wo = _its.shFrame.toWorld(wo);
                        Ray ray1 = Ray(its.p.val, wo);
                        Intersection its1;
                        if ( scene.rayIntersect(ray1, true, its1) && its1.isEmitter() ) {
                            // AD position x2, normal n2, and factor J of the emitter hit
                            // (J presumably a Jacobian / area-change term -- TODO confirm).
                            VectorAD x2, n2;
                            FloatAD J;
                            scene.getPoint(its1, x2, n2, J);

                            // Unit direction from the shading point to the emitter point.
                            VectorAD dir = x2 - its.p;
                            FloatAD distSqr = dir.squaredNorm();
                            dir /= distSqr.sqrt();

                            SpectrumAD light_contrib = its1.ptr_emitter->evalAD(n2, -dir);
                            if ( !light_contrib.isZero(Epsilon) ) {
                                SpectrumAD bsdf_val = its.evalBSDF(its.toLocal(dir));
                                // Emitter-side cosine over squared distance.
                                FloatAD G = n2.dot(-dir)/distSqr;
                                // Scale the sampled BSDF pdf by G (value only, no AD) so it
                                // is combined with the emitter pdf in the same measure.
                                pdf2 *= G.val;
                                pdf1 = scene.pdfEmitterSample(its1);
                                tmp += bsdf_val*light_contrib*G*J/(pdf1 + pdf2);
                            }
                        }
                    }
                }
            }
        }
        ret += tmp/static_cast<Float>(DIRECT_SAMPLES);
    }
    return ret;
}


Spectrum DirectIntegratorAD_PathSpace::pixelColor(const Scene &scene, const RenderOptions &options, RndSampler *sampler, Float x, Float y) const {
    // Build the camera ray for image coordinate (x, y) and return its
    // direct-lighting estimate. `options` is not consulted here.
    Ray primary(scene.camera.samplePrimaryRay(x, y));
    return Li(scene, sampler, primary);
}


SpectrumAD DirectIntegratorAD_PathSpace::pixelColorAD(const Scene &scene, const RenderOptions &options, RndSampler *sampler, Float x, Float y) const {
    // Differentiable counterpart of pixelColor: build the AD camera ray for
    // (x, y) and return LiAD's estimate. `options` is not consulted here.
    const RayAD primary(scene.camera.samplePrimaryRayAD(x, y));
    return LiAD(scene, sampler, primary);
}


void DirectIntegratorAD_PathSpace::render(const Scene &scene, const RenderOptions &_options, ptr<float> rendered_image) const {
    RenderOptions options = _options;
    if ( options.max_bounces != 1 ) {
        options.max_bounces = 1;
        if ( !options.quiet )
            std::cerr << "[WARN] Forcing max_bounces to one." << std::endl;
    }
    IntegratorAD_PathSpace::render(scene, options, rendered_image);
#ifdef USE_BOUNDARY_NEE
    if ( options.num_samples_secondary_edge_direct > 0 )
        renderEdgesDirect(scene, options, rendered_image);
#else
    if ( options.num_samples_secondary_edge_indirect > 0 )
        renderEdges(scene, options, rendered_image);
#endif
}
