#include "scene.h"
#include "binnedPtracerADps.h"

// When defined, traceParticleAD samples BSDFs with EBSDFMode::EImportance instead of
// EBSDFMode::EImportanceWithCorrection. Sensor connections in handleSurfaceInteractionAD
// always evaluate with EImportanceWithCorrection regardless of this macro.
#define IGNORE_SHADING_NORMAL_CORRECTION

// One recorded sensor contribution; entries are built as std::make_tuple(contrib, path_length, idx_pixel).
using RecordEntry = Binned_ParticleTracerAD_PathSpace::RecordEntry;

/**
 * Renders the binned image and its derivatives by particle tracing.
 *
 * options.num_samples blocks are traced, each consisting of camera.getNumPixels() particles,
 * so num_pixels * num_samples particles are traced in total. Each OpenMP thread accumulates
 * into its own slot of image_per_thread; the slots are then summed into rendered_image and
 * the whole buffer is divided by the particle count.
 *
 * Layout of rendered_image (RGB float triples):
 *   [bin][ch][pixel][rgb], where ch == 0 holds the value and ch == k (1..nder) holds grad(k-1).
 * Results are added with +=, so the caller presumably passes a zero-initialized buffer
 * (NOTE(review): confirm against callers).
 */
void Binned_ParticleTracerAD_PathSpace::renderInterior(const Scene &scene, const RenderOptions &options, 
    ptr<float> rendered_image) const {
    const auto &camera = scene.camera;
    const int size_block = camera.getNumPixels();
    const int num_block = options.num_samples;
    const int num_pixels = camera.getNumPixels();
    const int num_bins = camera.pif->num_bins;
    const int num_channels = nder + 1;  // value + nder derivative images
    const int nworker = omp_get_num_procs();
    const size_t image_size = size_t(num_bins) * num_pixels;

    // Reset every per-thread accumulator. assign() (unlike resize()) also zeroes buffers that
    // survive from a previous call, so repeated renders do not accumulate stale contributions.
    image_per_thread.resize(nworker);
    for (auto &image : image_per_thread)
        image.assign(image_size, SpectrumAD(0.0));

    // Number of finished blocks, used only for the progress indicator; modified under messageLock.
    int num_blocks_done = 0;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        // A loop iteration runs entirely on one thread, so the id is fixed for the whole block.
        const int thread_id = omp_get_thread_num();
        // NOTE(review): sample indices are ints, so num_pixels * num_samples must stay below INT_MAX.
        const int block_start = index_block * size_block;
        const int block_end = block_start + size_block;
        for (int index_sample = block_start; index_sample < block_end; index_sample++) {
            // The sampler is built from (seed, index_sample) only, not from the thread id.
            RndSampler sampler(options.seed, index_sample);
            traceParticleAD(scene, &sampler, options.max_bounces, thread_id, options.grad_threshold);
        }

        if (!options.quiet) {
            // Blocks complete out of order under dynamic scheduling; report how many are done
            // instead of this block's index, which could make the indicator move backwards.
            omp_set_lock(&messageLock);
            progressIndicator(Float(++num_blocks_done) / num_block);
            omp_unset_lock(&messageLock);
        }
    }

    // Reduce the per-thread accumulators into rendered_image (threads summed in index order).
    for (auto &image : image_per_thread) {
        for (int idx_bin = 0; idx_bin < num_bins; idx_bin++) {
            for (int idx_pixel = 0; idx_pixel < num_pixels; idx_pixel++) {
                auto &pixel_val = image[size_t(idx_bin) * num_pixels + idx_pixel];
                for (int ch = 0; ch < num_channels; ++ch) {
                    const size_t offset = ((size_t(idx_bin) * num_channels + ch) * num_pixels + idx_pixel) * 3;
                    if (ch == 0) {
                        for (int c = 0; c < 3; ++c)
                            rendered_image[offset + c] += static_cast<float>(pixel_val.val(c));
                    } else {
                        for (int c = 0; c < 3; ++c)
                            rendered_image[offset + c] += static_cast<float>((pixel_val.grad(ch - 1))(c));
                    }
                }
            }
        }
    }

    // Average over all traced particles.
    const size_t num_samples = size_t(size_block) * num_block;
    if (num_samples == 0)
        return;  // nothing was traced; dividing would turn the buffer into 0/0 = NaN
    const size_t num_floats = image_size * num_channels * 3;
    for (size_t i = 0; i < num_floats; i++)
        rendered_image[i] /= num_samples;
}

void Binned_ParticleTracerAD_PathSpace::traceParticleAD(const Scene &scene, RndSampler *sampler, int max_bounces, 
    int thread_id, Float grad_threshold) const {
    IntersectionAD itsAD, pre_itsAD;
    Float pdf;
    FloatAD J;
    SpectrumAD power = scene.sampleEmitterPosition(sampler->next2D(), itsAD, J, &pdf);
    Intersection its = itsAD.toIntersection();
    pre_itsAD = itsAD;

    std::vector<RecordEntry> records;

    handleEmissionAD(itsAD, scene, sampler, power, max_bounces, records);

    Ray ray;
    ray.org = its.p;
    //! it's nonAD since it's area light
    its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
    ray.dir = its.geoFrame.toWorld(ray.dir);
    int depth = 0;
    int hit = 0;
    FloatAD path_length = 0.0;
    SpectrumAD throughput(Spectrum::Ones());
    bool on_surface = true;

    while (!throughput.val.isZero() && depth < max_bounces) {
        scene.rayIntersect(ray, on_surface, its);
        int max_interactions = max_bounces - depth - 1;
        
        if (!its.isValid())
            break;

        // getPoint
        FloatAD J;
        scene.getPoint(its, pre_itsAD.p, itsAD, J);
        VectorAD x = itsAD.p;
        VectorAD dir = x - pre_itsAD.p;
        FloatAD dist = dir.norm();
        dir /= dist;
        path_length += dist;

        // calculate power
        if (hit == 0) {
            power *= pre_itsAD.ptr_emitter->evalDirectionAD(pre_itsAD.geoFrame.n, dir) / pdf;
        }
        // calculate preEvt bsdf/phase value
        SpectrumAD f(Spectrum::Ones());
        if (hit > 0) {
            VectorAD wo_local = pre_itsAD.toLocal(dir);
            Float pdf = pre_itsAD.pdfBSDF(wo_local.val);
            f = pre_itsAD.evalBSDF(wo_local) / pdf;
        }
        // calculate throughput
        FloatAD G = itsAD.geoFrame.n.dot(-dir) / dist.square();
        Float pdf = G.val;
        throughput *= f * G * J / pdf;
        handleSurfaceInteractionAD(itsAD, scene, nullptr, sampler, power * throughput, max_interactions, path_length, records);

        pre_itsAD = itsAD;
        hit++;

        Float bsdf_pdf, bsdf_eta;
        Vector wo_local, wo;
        EBSDFMode mode;
#ifdef IGNORE_SHADING_NORMAL_CORRECTION
        mode = EBSDFMode::EImportance;
#else
        mode = EBSDFMode::EImportanceWithCorrection;
#endif
        auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, mode);
        if (bsdf_weight.isZero())
            break;
        wo = its.toWorld(wo_local);
        /* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
        Vector wi = -ray.dir;
        Float wiDotGeoN = wi.dot(its.geoFrame.n),
                woDotGeoN = wo.dot(its.geoFrame.n);
        if (wiDotGeoN * its.wi.z() <= 0 || woDotGeoN * wo_local.z() <= 0) break;

        ray = Ray(its.p, wo);
        on_surface = true;

        depth++;
    }

    auto &image = image_per_thread[thread_id];
    auto &pif = scene.camera.pif;
    for (const auto &[contrib, path_length, idx_pixel] : records) {
        Vector2i bin_range = pif->getBinIndexRange(path_length.val);
        for (int idx_bin = bin_range[0]; idx_bin <= bin_range[1]; idx_bin++) {
            FloatAD pif_kernel = pif->evalAD(path_length, idx_bin);
            image[idx_bin * scene.camera.getNumPixels() + idx_pixel] += contrib * pif_kernel;
        }
    }
}

void Binned_ParticleTracerAD_PathSpace::handleEmissionAD(const IntersectionAD &its, const Scene &scene, 
    RndSampler *sampler, const SpectrumAD &weight, int max_bounces, std::vector<RecordEntry> &records) const {
    Vector2AD pix_uv;
    VectorAD dir;
    Matrix2x4AD pix_uvs;
    Vector4AD sensor_vals = scene.sampleAttenuatedSensorDirectAD(its, nullptr, sampler, max_bounces, pix_uvs, dir);
    FloatAD path_length = (its.p - scene.camera.cpos).norm();
    if (!sensor_vals.val.isZero()) {
        for (int i = 0; i < 4; i++) {
            if (sensor_vals(i) > Epsilon) {
                SpectrumAD contrib = weight * sensor_vals(i) * its.ptr_emitter->evalDirectionAD(its.geoFrame.n, dir);
                int idx_pixel = scene.camera.getPixelIndex(pix_uvs.val.col(i));
                records.push_back(std::make_tuple(contrib, path_length, idx_pixel));
            }
        }
    }
}

void Binned_ParticleTracerAD_PathSpace::handleSurfaceInteractionAD(const IntersectionAD &its, const Scene &scene, 
    const Medium *ptr_med, RndSampler *sampler, const SpectrumAD &weight, int max_bounces, FloatAD path_length,
    std::vector<RecordEntry> &records) const {
    Vector2AD pix_uv;
	VectorAD dir;
	const Camera &camera = scene.camera;

	Matrix2x4AD pix_uvs;
	Vector4AD sensor_vals = scene.sampleAttenuatedSensorDirectAD(its, ptr_med, sampler, max_bounces, pix_uvs, dir);
	if (!sensor_vals.val.isZero()) {
		VectorAD wi = its.toWorld(its.wi);
		VectorAD wo = dir, wo_local = its.toLocal(wo);
		/* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
		FloatAD wiDotGeoN = wi.dot(its.geoFrame.n),
			woDotGeoN = wo.dot(its.geoFrame.n);
		if (wiDotGeoN.val * its.wi.z().val <= 0 ||
		    woDotGeoN.val * wo_local.z().val <= 0)
			return;

        path_length += (its.p - camera.cpos).norm();

		/* Adjoint BSDF for shading normals -- [Veach, p. 155] */
		for (int i = 0; i < 4; i++) {
			if (sensor_vals(i) > Epsilon) {
				SpectrumAD contrib = sensor_vals(i) * its.ptr_bsdf->evalAD(its, wo_local, EBSDFMode::EImportanceWithCorrection) * weight;
				// !
                if (contrib.val.isNaN().maxCoeff())
                    continue;
				int idx_pixel = camera.getPixelIndex(pix_uvs.val.col(i));
                records.push_back(std::make_tuple(contrib, path_length, idx_pixel));
            }
		}
	}
}