#include "ptracerADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include <iomanip>

// Traces options.num_samples blocks of camera.getNumPixels() particles each
// and adds the averaged result to rendered_image.
//
// rendered_image is indexed as [entry*3 + channel] with entry in
// [0, num_pixels*(nder+1)). traceParticleAD stores the value of pixel p at
// entry p and derivative k of pixel p at entry (k+1)*num_pixels + p.
//
// The per-thread sums are added onto rendered_image and then every entry is
// divided by the total particle count, so whatever the buffer held on entry is
// scaled as well -- presumably callers pass it zero-initialised; TODO confirm.
//
// scene          : scene to render; scene.camera defines the pixel count.
// options        : num_samples, seed, max_bounces, grad_threshold and quiet are read.
// rendered_image : output buffer, at least num_pixels*(nder+1)*3 floats.
void ParticleTracerAD_PathSpace::renderAD(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    const auto &camera = scene.camera;
    const int num_pixels = camera.getNumPixels();
    const int size_block = num_pixels;          // particles traced per block
    const int num_block = options.num_samples;  // one block per sample pass
    const int num_entries = num_pixels*(nder+1);

    const int nworker = omp_get_num_procs();
    image_per_thread.resize(nworker);
    for (int i = 0; i < nworker; i++) {
        // assign(), not resize(): resize() only initialises elements it adds,
        // so a buffer kept from a previous call would retain its old sums and
        // they would be added to this render as well.
        image_per_thread[i].assign(num_entries, Spectrum::Zero());
    }

    // NOTE(review): sample indices are computed in int, so
    // num_pixels * num_samples above INT_MAX overflows. Confirm the index type
    // RndSampler accepts before widening these variables.
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        // A block is processed by exactly one thread, so the id is constant
        // for every sample in it.
        const int thread_id = omp_get_thread_num();
        const int block_start = index_block*size_block;
        const int block_end = block_start + size_block;
        for (int index_sample = block_start; index_sample < block_end; index_sample++) {
            // Each particle gets its own sampler, built from the seed and the sample index.
            RndSampler sampler(options.seed, index_sample);
            traceParticleAD(scene, &sampler, options.max_bounces, thread_id, options.grad_threshold);
        }

        if ( !options.quiet ) {
            // The lock keeps progress output from interleaving with warnings
            // printed by traceParticleAD under the same lock.
            omp_set_lock(&messageLock);
            progressIndicator(Float(index_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }

    // Reduce the per-thread accumulators into the output buffer.
    for (int ithread = 0; ithread < nworker; ithread++) {
        for (int idx_entry = 0; idx_entry < num_entries; idx_entry++) {
            for (int ichannel = 0; ichannel < 3; ichannel++) {
                rendered_image[idx_entry*3 + ichannel] += image_per_thread[ithread][idx_entry][ichannel];
            }
        }
    }

    // Nothing was traced: skip the normalisation, which would otherwise divide
    // by zero and fill the image with NaN before the edge terms are added.
    if ( size_block <= 0 || num_block <= 0 ) return;

    const size_t num_samples = size_t(size_block) * num_block;
    for (int idx_entry = 0; idx_entry < num_entries; idx_entry++) {
        for (int ichannel = 0; ichannel < 3; ichannel++) {
            rendered_image[idx_entry*3 + ichannel] /= num_samples;
        }
    }
}

// Traces a single particle that starts on an emitter and adds its sensor
// contributions (value and derivatives) to image_per_thread[thread_id].
//
// The path vertices are stored in m_path[thread_id]. Vertex 0 is sampled on an
// emitter, vertex 1 (if max_bounces > 0) is found by shooting a ray from it,
// and further vertices (if max_bounces > 1) come from bidir::buildPathAD.
// Every vertex is then connected to the sensor via
// scene.sampleAttenuatedSensorDirect.
//
// scene          : provides emitter sampling, ray intersection and the camera.
// sampler        : random number source; the order of draws below determines
//                  the path, so statements must not be reordered.
// max_bounces    : 0 keeps only the emitter vertex; 1 allows one extra vertex;
//                  larger values are forwarded to bidir::buildPathAD.
// thread_id      : index into m_path and image_per_thread.
// grad_threshold : a contribution is rejected when its largest absolute
//                  derivative entry is not below this value.
void ParticleTracerAD_PathSpace::traceParticleAD(const Scene& scene, RndSampler *sampler, int max_bounces, int thread_id, Float grad_threshold) const
{
    // Path storage for this thread.
    bidir::PathNodeAD *lightPath = m_path[thread_id];

    // Vertex 0: a point sampled on an emitter. The throughput is scaled by the
    // J factor returned with it (presumably the Jacobian of the AD
    // parameterisation of the point -- TODO confirm).
    lightPath[0].throughput = scene.sampleEmitterPosition(sampler->next2D(), lightPath[0].itsAD, lightPath[0].J, &lightPath[0].pdf0);
    lightPath[0].throughput *= lightPath[0].J;
    // Keep a non-AD copy of the vertex next to the AD one.
    lightPath[0].its = lightPath[0].itsAD.toIntersection();
    int lightPathLen = 1;
    if ( max_bounces > 0) {
        Vector wo;
        // pdf aliases the node's pdf0, so the update below is stored in the path.
        Float &pdf = lightPath[1].pdf0;
        // Sample an emission direction, then move it into world space using
        // the emitter's geometric frame.
        lightPath[0].its.ptr_emitter->sampleDirection(sampler->next2D(), wo, &pdf);
        wo = lightPath[0].its.geoFrame.toWorld(wo);
        if ( scene.rayIntersect(Ray(lightPath[0].its.p, wo), true, lightPath[1].its) ) {
            // Build the AD version of vertex 1 and its J factor from the hit
            // and the AD position of vertex 0.
            scene.getPoint(lightPath[1].its, lightPath[0].itsAD.p, lightPath[1].itsAD, lightPath[1].J);
            // Recompute the outgoing direction from the two AD positions so
            // that it carries derivatives.
            lightPath[0].wo = lightPath[1].itsAD.p - lightPath[0].itsAD.p;
            FloatAD d = lightPath[0].wo.norm();
            lightPath[0].wo /= d;
            // G = |n1 . (-wo)| / d^2 at vertex 1.
            FloatAD G = lightPath[1].itsAD.geoFrame.n.dot(-lightPath[0].wo).abs() / d.square();
            // Only the value of G is folded into the pdf (no derivative);
            // presumably this converts the directional pdf to an area measure.
            pdf *= G.val;
            lightPath[1].throughput = lightPath[0].throughput * lightPath[0].its.ptr_emitter->evalDirectionAD(lightPath[0].itsAD.geoFrame.n, lightPath[0].wo) * lightPath[1].J * G / pdf;
            if ( max_bounces > 1)
                // The returned count is offset by one because the builder is
                // handed the path starting at vertex 1.
                lightPathLen = bidir::buildPathAD(scene, sampler, max_bounces, true, &lightPath[1]) + 1;
            else
                lightPathLen = 2;
        }
    }

    const Camera& camera = scene.camera;
    int num_pixels = camera.getNumPixels();
    // Connect every path vertex to the sensor.
    for (int i = 0; i < lightPathLen; i++) {
        Vector2 pix_uv;
        Vector dir;
        Intersection& its = lightPath[i].its;
        // Outputs the pixel coordinates and the direction used for the
        // connection; a zero return means there is nothing to add for this vertex.
        Float transmittance = scene.sampleAttenuatedSensorDirect(its, sampler, max_bounces, pix_uv, dir);
        if ( transmittance != 0.0 ) {
            SpectrumAD value = SpectrumAD(Spectrum::Zero());
            if ( i == 0 ) {
                // Emitter vertex seen directly: the throughput was multiplied
                // by J above, so dividing by J removes that factor again.
                value = (lightPath[0].throughput/lightPath[0].J) * transmittance * its.ptr_emitter->evalDirection(its.geoFrame.n, dir);
            }
            else {
                Vector wi = its.toWorld(its.wi), wo = dir, wo_local = its.toLocal(wo);
                Float wiDotGeoN = wi.dot(its.geoFrame.n), woDotGeoN = wo.dot(its.geoFrame.n);
                // Skip the vertex when, for wi or wo, the local-frame z sign
                // and the geometric-normal side disagree.
                if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) continue;

                // Re-trace the camera ray through the sampled pixel
                // coordinates in AD form, and use it only if it lands within
                // ShadowEpsilon of this vertex.
                RayAD cameraRay = scene.camera.samplePrimaryRayAD(pix_uv[0], pix_uv[1]);
                IntersectionAD its_ad;
                if ( scene.rayIntersectAD(cameraRay, false, its_ad) && (its_ad.p.val - its.p).norm() < ShadowEpsilon ) {
                    // value = lightPath[i].throughput_0;
                    // Start from the previous vertex's throughput and
                    // re-evaluate the last segment against its_ad.
                    value = lightPath[i-1].throughput;
                    VectorAD dir0 = lightPath[i-1].itsAD.p - its_ad.p;
                    FloatAD d = dir0.norm();
                    dir0 /= d;
                    // G = |n . dir0| / d^2 at the re-traced hit point.
                    FloatAD G = its_ad.geoFrame.n.dot(dir0).abs()/d.square();
                    if ( i == 1 ){
                        // Previous vertex is the emitter: use its directional emission.
                        value *= lightPath[i-1].its.ptr_emitter->evalDirectionAD(lightPath[i-1].itsAD.geoFrame.n, -dir0) * G / lightPath[i].pdf0;
                    }
                    else {
                        // Previous vertex is a surface: use its BSDF.
                        value *= lightPath[i-1].its.ptr_bsdf->evalAD(lightPath[i-1].itsAD, lightPath[i-1].itsAD.toLocal(-dir0), EBSDFMode::EImportanceWithCorrection) * G / lightPath[i].pdf0;
                    }

                    // Evaluate the BSDF at the connection vertex: incoming
                    // from the previous vertex, outgoing towards the camera.
                    its_ad.wi = its_ad.toLocal(dir0);
                    VectorAD woLocal = its_ad.toLocal(-cameraRay.dir);
                    // value *= transmittance * its_ad.ptr_bsdf->evalAD(its_ad, woLocal, EBSDFMode::EImportanceWithCorrection) / woLocal.z() * woLocal.val.z();
                    value *= transmittance * its_ad.ptr_bsdf->evalAD(its_ad, woLocal, EBSDFMode::EImportanceWithCorrection);
                }
                // If the re-trace fails, value stays zero and a zero
                // contribution is recorded below.
            }


            // Value is valid when all three channels are finite and non-negative.
            bool val_valid = std::isfinite(value.val[0]) && std::isfinite(value.val[1]) && std::isfinite(value.val[2]) && value.val.minCoeff() >= 0.0f;
            // Derivatives are valid when the largest absolute entry is finite
            // and below grad_threshold.
            Float tmp_val = value.der.abs().maxCoeff();
            bool der_valid = std::isfinite(tmp_val) && tmp_val < grad_threshold;

            if ( val_valid && der_valid ) {
                // Buffer layout: entry idx_pixel holds the value, entry
                // ch*num_pixels + idx_pixel holds derivative ch-1.
                int idx_pixel = camera.getPixelIndex(pix_uv);
                image_per_thread[thread_id][idx_pixel] += value.val;
                for (int ch = 1; ch <= nder; ch++) {
                    image_per_thread[thread_id][ch*num_pixels + idx_pixel] += value.grad(ch - 1);
                }
            } else {
                // The contribution is dropped; the lock keeps warnings from
                // different threads from interleaving.
                omp_set_lock(&messageLock);
                if (!val_valid)
                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << value.val << "]" << std::endl;
                if (!der_valid)
                    std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << value.der << "]" << std::endl;
                omp_unset_lock(&messageLock);
            }
        }
    }
}


// Renders the full image by running each enabled term in turn on the same
// rendered_image buffer: the non-edge term first, then the secondary-edge
// terms, then the primary-edge term.
void ParticleTracerAD_PathSpace::render(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    // Non-edge term: always rendered.
    renderAD(scene, options, rendered_image);

#ifdef USE_BOUNDARY_NEE
    // Direct secondary-edge term: only compiled in with USE_BOUNDARY_NEE.
    const bool wantDirectEdges = options.num_samples_secondary_edge_direct > 0;
    if ( wantDirectEdges ) {
        renderEdgesDirect(scene, options, rendered_image);
    }
#endif

    // Indirect secondary-edge term.
    const bool wantIndirectEdges = options.num_samples_secondary_edge_indirect > 0;
    if ( wantIndirectEdges ) {
        renderEdges(scene, options, rendered_image);
    }

    // Primary-edge term: needs samples requested and at least one primary
    // edge. The edge manager is queried only once samples are requested.
    if ( !(options.num_samples_primary_edge > 0) ) {
        return;
    }
    if ( scene.ptr_edgeManager->getNumPrimaryEdges() > 0 ) {
        IntegratorAD::renderPrimaryEdges(scene, options, rendered_image);
    }
}