#include "invDirectADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include <iomanip>

void InvDirectAD_PathSpace::renderAD(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    int size_block = 100000;                //
    int num_block = options.num_samples;
    const auto &camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    const int nworker = omp_get_num_procs();
    image_per_thread.resize(nworker);
    for (int i = 0; i < nworker; i++) {
        image_per_thread[i].resize(num_pixels *(nder+1), Spectrum::Zero());
    }

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        int block_start = index_block*size_block;
        int block_end = (index_block+1)*size_block;
        for (int index_sample = block_start; index_sample < block_end; index_sample++) {
            RndSampler sampler(options.seed, index_sample);
            int thread_id = omp_get_thread_num();
            traceParticleAD(scene, &sampler, options.max_bounces, thread_id);
        }

        if ( !options.quiet ) {
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }

    for (int ithread = 0; ithread < nworker; ithread++) {
        for (int idx_entry = 0; idx_entry < num_pixels*(nder+1); idx_entry++) {
            for (int ichannel = 0; ichannel < 3; ichannel++) {
                rendered_image[idx_entry*3 + ichannel] += image_per_thread[ithread][idx_entry][ichannel];
            }
        }
    }

    size_t num_samples = size_t(size_block) * num_block;
    for (int idx_entry = 0; idx_entry < num_pixels*(nder+1); idx_entry++) {
        for (int ichannel = 0; ichannel < 3; ichannel++) {
            rendered_image[idx_entry*3 + ichannel] /= num_samples;
        }
    }
}



/**
 * Traces one particle and performs two sensor connections:
 *  1. the emitter sample itself is connected to the sensor (handleEmissionAD);
 *  2. one ray is shot in an emitter-sampled direction, and its first surface hit
 *     is connected to the sensor (handleSurfaceInteractionAD).
 *
 * @param sampler      per-particle random sampler
 * @param max_bounces  forwarded to the sensor-connection helpers
 * @param thread_id    index of the per-thread image buffer to splat into
 */
void InvDirectAD_PathSpace::traceParticleAD(const Scene& scene, RndSampler *sampler, int max_bounces, int thread_id) const {
    IntersectionAD its;
    FloatAD J;
    SpectrumAD power = scene.sampleEmitterPosition(sampler->next2D(), its, J);
    // Fold the differentiable factor J returned by the position sampler into the power.
    power *= J;

    // Connect the emitter sample directly to the sensor. Only power.val (a plain
    // Spectrum) is passed on here.
    handleEmissionAD(its.toIntersection(), scene, sampler, power.val, max_bounces, thread_id);

    // Sample an emission direction in the emitter's local frame.
    Ray ray;
    ray.org = its.p.val;
    Float pdf = 0.0;
    its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
    // A zero or invalid pdf would make power/pdf infinite and splat Inf/NaN.
    if (!(pdf > 0.0))
        return;
    ray.dir = its.geoFrame.toWorld(ray.dir).val;

    Intersection its1;
    if (scene.rayIntersect(ray, true, its1)) {
        power /= pdf;
        handleSurfaceInteractionAD(its1, its, scene, nullptr, sampler, power, max_bounces, thread_id);
    }
}


void InvDirectAD_PathSpace::handleEmissionAD(const Intersection& its, const Scene& scene, RndSampler *sampler,
                                             const Spectrum& weight, int max_bounces, int thread_id) const
{
    Vector2 pix_uv;
    Vector dir;
    Float transmittance = scene.sampleAttenuatedSensorDirect(its, sampler, max_bounces, pix_uv, dir);

    if (transmittance != 0.0) {
        SpectrumAD value = SpectrumAD(weight) * transmittance * its.ptr_emitter->evalDirection(its.geoFrame.n, dir);
        int idx_pixel = scene.camera.getPixelIndex(pix_uv);
        int num_pixels = scene.camera.getNumPixels();
        image_per_thread[thread_id][idx_pixel] += value.val;

        for (int ch = 1; ch <= nder; ch++) {
            image_per_thread[thread_id][ch*num_pixels + idx_pixel] += value.grad(ch - 1);
        }
    }
}


/**
 * Connects a particle's surface hit to the sensor, with derivatives.
 *
 * The hit point is re-derived differentiably: the camera ray through the pixel the
 * connection landed in is re-traced with AD, and the contribution is splatted only
 * if that ray reaches the same point as `its`.
 *
 * @param its        particle hit point (non-AD)
 * @param its0_ad    emitter sample the particle left from (AD)
 * @param ptr_med    not used in this body
 * @param weight     particle weight carried so far (AD)
 * @param thread_id  index of the per-thread image buffer to write
 */
void InvDirectAD_PathSpace::handleSurfaceInteractionAD(const Intersection& its, const IntersectionAD its0_ad, const Scene& scene, const Medium* ptr_med, RndSampler *sampler,
                                                       const SpectrumAD& weight, int max_bounces, int thread_id) const {
    Vector2 pix_uv;
    Vector dir;
    Float transmittance = scene.sampleAttenuatedSensorDirect(its, sampler, max_bounces, pix_uv, dir);

    if (transmittance != 0.0f) {
        // Re-trace the primary ray through the connected pixel to get a differentiable hit.
        RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
        IntersectionAD its_ad;
        // Accept only if the re-traced ray lands (within ShadowEpsilon) on the particle hit.
        if ( scene.rayIntersectAD(cameraRay, false, its_ad) && (its_ad.p.val - its.p).norm() < ShadowEpsilon ) {
            SpectrumAD value = weight;

            // Unit direction from the hit toward the emitter sample, and its distance.
            VectorAD dir0 = its0_ad.p - its_ad.p;
            FloatAD d = dir0.norm();
            dir0 /= d;
            // G / G.val equals 1 in value but carries the derivatives of G (|cos| / d^2 at the hit).
            FloatAD G = its_ad.geoFrame.n.dot(dir0).abs()/d.square();
            value *= its0_ad.ptr_emitter->evalDirectionAD(its0_ad.geoFrame.n, -dir0) * G / G.val;

            // The BSDF evaluation below reads its_ad.wi, so it must be set first.
            its_ad.wi = its_ad.toLocal(dir0);
            VectorAD wi =  dir0,
                     wo = -cameraRay.dir, wo_local = its_ad.toLocal(wo);
            /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
            Float wiDotGeoN = wi.val.dot(its.geoFrame.n), woDotGeoN = wo.val.dot(its.geoFrame.n);
            if (wiDotGeoN * its_ad.wi.val.z() <= 0 || woDotGeoN * wo_local.val.z() <= 0)
                return;
            // "/ wo_local.z() * wo_local.val.z()" is 1 in value; it only contributes the derivative of 1/cos(wo).
            value *= transmittance * its_ad.ptr_bsdf->evalAD(its_ad, wo_local, EBSDFMode::EImportanceWithCorrection) / wo_local.z() * wo_local.val.z();
            int idx_pixel = scene.camera.getPixelIndex(pix_uv);
            int num_pixels = scene.camera.getNumPixels();

            // Image 0 receives the value ...
            image_per_thread[thread_id][idx_pixel] += value.val;

            // ... and image ch receives derivative ch-1.
            for (int ch = 1; ch <= nder; ch++) {
                image_per_thread[thread_id][ch*num_pixels + idx_pixel] += value.grad(ch - 1);
            }
        }
    }
}

/**
 * Full render: the interior term, then the optional secondary-edge and
 * primary-edge terms. Each term writes into `rendered_image`.
 */
void InvDirectAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    // Interior (non-edge) term.
    renderAD(scene, options, rendered_image);

    // Secondary-edge terms, enabled by a positive secondary-edge sample count.
    const bool do_secondary_edges = options.num_samples_secondary_edge > 0;
    if (do_secondary_edges) {
        renderEdgesDirect(scene, options, rendered_image);
        renderEdges(scene, options, rendered_image);
    }

    // Primary-edge term; the edge count is only queried when samples are requested.
    const bool do_primary_edges = options.num_samples_primary_edge > 0
                               && scene.ptr_edgeManager->getNumPrimaryEdges() > 0;
    if (do_primary_edges)
        IntegratorAD::renderPrimaryEdges(scene, options, rendered_image);
}
