#include "ptracerADps.h"
#include "scene.h"
#include "ray.h"
#include "sampler.h"
#include <assert.h>
#include <omp.h>
#include <iomanip>
#include <iostream>
#include <numeric>

// Print "<expression text>: <value>" followed by a newline (and flush) to std::cerr,
// using 15 significant digits for floating-point values.
// The argument is parenthesized in the expansion so that expressions containing
// operators that bind less tightly than `<<` (comparisons, ternaries, bit-ops)
// are streamed as a single value.
// Note: std::setprecision is sticky, so std::cerr keeps precision 15 afterwards.
#define DEBUG(x)                                                              \
	do                                                                        \
	{                                                                         \
		std::cerr << std::setprecision(15) << #x << ": " << (x) << std::endl; \
	} while (0)

#define IGNORE_SHADING_NORMAL_CORRECTION
#define USE_CAMERA_FILTER

/**
 * Trace particles in parallel and accumulate the result into `rendered_image`.
 *
 * Work is split into `options.num_samples` blocks of `camera.getNumPixels()` particles each.
 * Every OpenMP worker splats into its own buffer (`image_per_thread[thread_id]`); the buffers
 * are summed into `rendered_image` afterwards and the sum is divided by the total particle count.
 *
 * Layout written to `rendered_image` (3 floats per pixel):
 *   - entries [0, num_pixels)                      : primal value (`val`)
 *   - entries [ch * num_pixels, (ch+1)*num_pixels) : `grad(ch - 1)` for ch = 1..nder
 *
 * @param scene           scene to render; its camera defines the pixel count
 * @param options         supplies num_samples, max_bounces, seed and grad_threshold
 * @param rendered_image  output buffer; values are ADDED to its current content before
 *                        normalization, so the existing content contributes to the result
 *                        (NOTE(review): presumably the caller zero-initializes it — confirm)
 */
void ParticleTracerADps::renderInterior(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const
{
	const auto &camera = scene.camera;
	const int num_pixels = camera.getNumPixels();
	const int size_block = num_pixels; // particles traced per block
	const int num_block = options.num_samples;
	const int nworker = omp_get_num_procs();
	// const int nworker = 1;

#ifndef NDEBUG
	DEBUG("debug mode");
#endif

	DEBUG(nworker);

	// One accumulation buffer per worker, reset to default-constructed SpectrumAD values.
	image_per_thread.resize(nworker);
	for (int i = 0; i < nworker; i++)
	{
		auto &image = image_per_thread[i];
		image.resize(num_pixels);
		for (int j = 0; j < num_pixels; j++)
			image[j] = SpectrumAD();
	}

	// The original issued this identical reserve twice per worker iteration; it does not
	// depend on the worker index, so a single call is equivalent.
	// NOTE(review): the name suggests per-thread storage — confirm whether an indexed
	// reserve (path_per_thread[i]) was intended.
	path_per_thread.reserve(options.max_bounces + 2);

	int finished = 0;
#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
	for (int index_block = 0; index_block < num_block; index_block++)
	{
		// A single loop iteration is executed by one thread, so the id is block-invariant.
		const int thread_id = omp_get_thread_num();
		// NOTE(review): these are int products; num_block * size_block above INT_MAX
		// would overflow — confirm the expected pixel/sample ranges.
		const int block_start = index_block * size_block;
		const int block_end = (index_block + 1) * size_block;
		for (int index_sample = block_start; index_sample < block_end; index_sample++)
		{
			// Each particle gets its own sampler built from (seed, global particle index).
			RndSampler sampler(options.seed, index_sample);
			traceParticleAD(scene, &sampler, options.max_bounces, thread_id, options.grad_threshold);
		}
#pragma omp critical
		progressIndicator(Float(++finished) / num_block);
	}

	// Sum the per-worker buffers into the output image.
	for (int ithread = 0; ithread < nworker; ithread++)
	{
		for (int idx_pixel = 0; idx_pixel < num_pixels; idx_pixel++)
		{
			auto &pixel_val = image_per_thread[ithread][idx_pixel];

			// primal value
			for (int c = 0; c < 3; c++)
				rendered_image[idx_pixel * 3 + c] += static_cast<float>(pixel_val.val(c));

			// derivative images, one per differentiated parameter
			for (int ch = 1; ch <= nder; ++ch)
			{
				const int offset = (ch * num_pixels + idx_pixel) * 3;
				for (int c = 0; c < 3; c++)
					rendered_image[offset + c] += static_cast<float>((pixel_val.grad(ch - 1))(c));
			}
		}
	}

	// Normalize primal and derivative images by the total number of traced particles.
	size_t num_samples = size_t(size_block) * num_block;
	for (int idx_entry = 0; idx_entry < num_pixels * (nder + 1); idx_entry++)
	{
		for (int ichannel = 0; ichannel < 3; ichannel++)
		{
			rendered_image[idx_entry * 3 + ichannel] /= num_samples;
		}
	}
}

/**
 * Trace a single particle from an emitter through the scene, connecting every scattering
 * vertex to the sensor via the handle*AD helpers (which write into image_per_thread[thread_id]).
 *
 * The walk itself (ray directions, distances, BSDF/phase samples) is carried out with non-AD
 * quantities; at every vertex the AD position is re-derived (getPoint) and the AD throughput
 * is rebuilt from the previous vertex record `preEvt`.
 *
 * @param scene           scene being rendered
 * @param sampler         random sampler; the order of draws below determines the path
 * @param max_bounces     loop bound on `depth`; also used to derive the budget passed to the handlers
 * @param thread_id       index of the per-thread accumulation buffer, forwarded to the handlers
 * @param grad_threshold  not referenced anywhere in this function
 */
void ParticleTracerADps::traceParticleAD(const Scene &scene, RndSampler *sampler, int max_bounces, int thread_id, Float grad_threshold) const
{
	IntersectionAD itsAD;
	Float pdf;
	FloatAD J;
	// Sample a position on an emitter; fills itsAD, J and pdf.
	SpectrumAD power = scene.sampleEmitterPosition(sampler->next2D(), itsAD, J, &pdf);
	// SpectrumAD power = scene.sampleEmitterPosition(sampler->next2D(), its, pdf);
	Intersection its = itsAD.toIntersection();

	// Connect the emitter vertex itself to the sensor.
	handleEmissionAD(itsAD, scene, sampler, power, max_bounces, thread_id);

	Ray ray;
	ray.org = its.p;
	//! it's nonAD since it's area light
	// Note: this overwrites `pdf` (previously written by sampleEmitterPosition); from here on
	// the function-scope `pdf` is whatever sampleDirection stored.
	its.ptr_emitter->sampleDirection(sampler->next2D(), ray.dir, &pdf);
	// With a shape attached, the sampled direction is mapped through the geometric frame.
	if (its.ptr_shape != nullptr)
		ray.dir = its.geoFrame.toWorld(ray.dir);
	int depth = 0; // loop iterations so far (incremented for every segment, see end of loop)
	int hit = 0;   // number of vertices that were connected to the sensor inside the loop
	SpectrumAD throughput(Spectrum::Ones());
	const Medium *ptr_med = its.getTargetMedium(ray.dir);
	bool on_surface = true;

	// Record of the previous path vertex; starts as the emitter vertex.
	EventRecordAD preEvt(itsAD);

	while (!throughput.val.isZero() && depth < max_bounces)
	{
		scene.rayIntersect(ray, on_surface, its);
		// Passed to the handlers as their `max_bounces` argument.
		int max_interactions = max_bounces - depth - 1;
		Spectrum T = Spectrum::Ones();
		// Short-circuit: the sampler draw inside sampleDistance only happens when a medium is present.
		// NOTE(review): ray.org and T are passed by name here and ray.org is read right after;
		// presumably sampleDistance writes the sampled scattering position into ray.org — confirm.
		bool inside_med = ptr_med != nullptr &&
						  ptr_med->sampleDistance(ray, its.t, sampler->next2D(), sampler, ray.org, T);
		if (inside_med)
		{
			// getPoint
			VectorAD x;
			FloatAD J; // shadows the function-scope J
			ptr_med->getPoint(ray.org, x, J);
			// Unit direction and distance from the previous vertex to this one (AD).
			VectorAD dir = x - preEvt.x();
			FloatAD dist = dir.norm();
			dir /= dist;
			// calculate power if this is the first scattering
			if (hit == 0)
			{
				// `pdf` here is still the function-scope variable: the local `pdf` below is
				// declared later in this block and is not yet in scope.
				power *= preEvt.its.ptr_emitter->evalDirectionAD(preEvt.its.geoFrame.n, dir) / pdf;
			}
			// calculate preEvt phase/bsdf value
			SpectrumAD f(Spectrum::Ones());
			if (hit > 0)
				f = preEvt.f(scene, dir);
			// calculate throughput
			FloatAD G = 1.0 / dist.square();
			// This FloatAD T shadows the Spectrum T declared at the top of the loop body.
			FloatAD T = scene.evalTransmittanceAD(RayAD(preEvt.x(), dir), preEvt.onSurface, preEvt.getMedium(dir.val), dist, sampler, depth);
			SpectrumAD sigS = ptr_med->sigSAD(x);
			Float sigT = ptr_med->sigT(x.val);
			// Local non-AD pdf (shadows the function-scope pdf), built from the primal parts of T and G.
			Float pdf = T.val * sigT * G.val;
			throughput *= f * T * sigS * G * J / pdf;
			// Connect this medium vertex to the sensor.
			handleMediumInteractionAD(scene, ptr_med, RayAD(x, dir), sampler, throughput * power, max_interactions, thread_id);
			const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
			Vector wo;
			// Sample the next direction; the returned value is only tested against zero and
			// is not multiplied into throughput (the next vertex applies preEvt.f instead).
			Float phase_val = ptr_phase->sample(-ray.dir, sampler->next2D(), wo);
			if (phase_val == 0)
				break;
			ray.dir = wo;
			on_surface = false;

			preEvt = EventRecordAD(x, -dir, ptr_med);
			hit++;
		}
		else
		{
			// No medium event: the particle either leaves the scene or reaches a surface.
			if (!its.isValid())
				break;
			// Only surfaces with a non-null BSDF become path vertices / sensor connections.
			if (!its.ptr_bsdf->isNull())
			{
				// getPoint
				FloatAD J; // shadows the function-scope J
				scene.getPoint(its, preEvt.x(), itsAD, J);
				VectorAD x = itsAD.p;
				VectorAD dir = x - preEvt.x();
				FloatAD dist = dir.norm();
				dir /= dist;
				// calculate power
				if (hit == 0)
				{
					// Uses the function-scope pdf (the local pdf below is not yet declared).
					power *= preEvt.its.ptr_emitter->evalDirectionAD(preEvt.its.geoFrame.n, dir) / pdf;
				}
				// calculate preEvt bsdf/phase value
				SpectrumAD f(Spectrum::Ones());
				if (hit > 0)
					f = preEvt.f(scene, dir);
				// calculate throughput
				FloatAD G = itsAD.geoFrame.n.dot(-dir) / dist.square();
				// Shadows the Spectrum T declared at the top of the loop body.
				FloatAD T = scene.evalTransmittanceAD(RayAD(preEvt.x(), dir), preEvt.onSurface, preEvt.getMedium(dir.val), dist, sampler, depth);
				// Local non-AD pdf (shadows the function-scope pdf), built from the primal parts of T and G.
				Float pdf = T.val * G.val;
				throughput *= f * T * G * J / pdf;
				// Connect this surface vertex to the sensor.
				handleSurfaceInteractionAD(itsAD, scene, ptr_med, sampler, power * throughput, max_interactions, thread_id);

				preEvt = EventRecordAD(itsAD);
				hit++;
			}
			// BSDF sampling runs for both null and non-null BSDFs to pick the next direction.
			// bsdf_pdf and bsdf_eta are passed to sampleBSDF but not read afterwards.
			Float bsdf_pdf, bsdf_eta;
			Vector wo_local, wo;
			EBSDFMode mode;
			// IGNORE_SHADING_NORMAL_CORRECTION is defined at the top of this file, so the
			// first branch is the one compiled.
#ifdef IGNORE_SHADING_NORMAL_CORRECTION
			mode = EBSDFMode::EImportance;
#else
			mode = EBSDFMode::EImportanceWithCorrection;
#endif
			auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo_local, bsdf_pdf, bsdf_eta, mode);
			if (bsdf_weight.isZero())
				break;
			wo = its.toWorld(wo_local);
			/* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
			// Terminate when the incoming or outgoing direction lies on different sides with
			// respect to the geometric normal and the local frame's z axis.
			Vector wi = -ray.dir;
			Float wiDotGeoN = wi.dot(its.geoFrame.n),
				  woDotGeoN = wo.dot(its.geoFrame.n);
			if (wiDotGeoN * its.wi.z() <= 0 ||
				woDotGeoN * wo_local.z() <= 0)
			{
				break;
			}
			// The sampled weight is intentionally not applied here (line kept commented out);
			// the next vertex multiplies throughput by preEvt.f(scene, dir).
			// throughput *= bsdf_weight;
			if (its.isMediumTransition())
				ptr_med = its.getTargetMedium(woDotGeoN);
			ray = Ray(its.p, wo);
			on_surface = true;
		}
		// Incremented for every loop iteration that did not break, including null-BSDF surfaces.
		depth++;
	}
}

void ParticleTracerADps::handleEmissionAD(const IntersectionAD &its, const Scene &scene, RndSampler *sampler,
										  const SpectrumAD &weight, int max_bounces, int thread_id) const
{
	Vector2AD pix_uv;
	VectorAD dir;
	Matrix2x4AD pix_uvs;
	Vector4AD sensor_vals = scene.sampleAttenuatedSensorDirectAD(its, nullptr, sampler, max_bounces, pix_uvs, dir);
	if (!sensor_vals.val.isZero())
	{
		for (int i = 0; i < 4; i++)
		{
			if (sensor_vals(i) > Epsilon)
			{
				SpectrumAD value = weight * sensor_vals(i) * its.ptr_emitter->evalDirectionAD(its.geoFrame.n, dir);

				int idx_pixel = scene.camera.getPixelIndex(pix_uvs.val.col(i));
				image_per_thread[thread_id][idx_pixel] += value;
			}
		}
	}
}

/**
 * Connect a surface vertex to the sensor and accumulate its contribution.
 *
 * Two compile-time variants exist:
 *   - USE_CAMERA_FILTER undefined: a single sensor connection (one pixel, one transmittance).
 *   - USE_CAMERA_FILTER defined (as at the top of this file): the connection yields four
 *     sensor values / pixel coordinates and every lane above Epsilon is splatted.
 *
 * In both variants the contribution is rejected when the incoming or outgoing direction is
 * inconsistent between the geometric normal and the local shading frame.
 *
 * @param its          surface vertex (its.wi, its.geoFrame.n and the BSDF are read)
 * @param scene        scene providing the sensor connection and the camera
 * @param ptr_med      medium forwarded to the sensor connection
 * @param sampler      random sampler forwarded to the sensor connection
 * @param weight       value carried by the particle at this vertex
 * @param max_bounces  forwarded to the sensor connection
 * @param thread_id    index into image_per_thread
 */
void ParticleTracerADps::handleSurfaceInteractionAD(const IntersectionAD &its, const Scene &scene, const Medium *ptr_med, RndSampler *sampler,
													const SpectrumAD &weight, int max_bounces, int thread_id) const
{
	VectorAD dir;
	const Camera &camera = scene.camera;

#ifndef USE_CAMERA_FILTER
	Vector2AD pix_uv;
	FloatAD transmittance = scene.sampleAttenuatedSensorDirectAD(its, ptr_med, sampler, max_bounces, pix_uv, dir);
	if (transmittance.val != 0.0f)
	{
		VectorAD wi = its.toWorld(its.wi);
		VectorAD wo = dir, wo_local = its.toLocal(wo);
		/* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
		FloatAD wiDotGeoN = wi.dot(its.geoFrame.n),
				woDotGeoN = wo.dot(its.geoFrame.n);
		if (wiDotGeoN.val * its.wi.z().val <= 0 ||
			woDotGeoN.val * wo_local.z().val <= 0)
			return;

		/* Adjoint BSDF for shading normals -- [Veach, p. 155] */
		SpectrumAD value = transmittance * its.ptr_bsdf->evalAD(its, wo_local, EBSDFMode::EImportanceWithCorrection) * weight;
		int idx_pixel = camera.getPixelIndex(pix_uv.val);
		image_per_thread[thread_id][idx_pixel] += value;
	}
#else
	// When isHybrid() holds, surfaces whose BSDF is both two-sided and transmissive
	// (labelled "dielectric" in the original code) are not connected here.
	if (isHybrid() && its.getBSDF()->isTwosided() && its.getBSDF()->isTransmissive())
		return;

	Matrix2x4AD pix_uvs;
	Vector4AD sensor_vals = scene.sampleAttenuatedSensorDirectAD(its, ptr_med, sampler, max_bounces, pix_uvs, dir);
	if (sensor_vals.val.isZero())
		return;

	VectorAD wi = its.toWorld(its.wi);
	VectorAD wo = dir, wo_local = its.toLocal(wo);
	/* Prevent light leaks due to the use of shading normals -- [Veach, p. 158] */
	FloatAD wiDotGeoN = wi.dot(its.geoFrame.n),
			woDotGeoN = wo.dot(its.geoFrame.n);
	if (wiDotGeoN.val * its.wi.z().val <= 0 ||
		woDotGeoN.val * wo_local.z().val <= 0)
		return;

	/* Adjoint BSDF for shading normals -- [Veach, p. 155] */
	// The BSDF value depends only on the vertex and wo_local, not on the lane: evaluate once.
	const auto bsdf_val = its.ptr_bsdf->evalAD(its, wo_local, EBSDFMode::EImportanceWithCorrection);

	for (int i = 0; i < 4; i++)
	{
		if (sensor_vals(i) > Epsilon)
		{
			SpectrumAD value = sensor_vals(i) * bsdf_val * weight;
			// Skip contributions whose primal value contains a NaN.
			if (value.val.isNaN().maxCoeff())
				continue;
			int idx_pixel = camera.getPixelIndex(pix_uvs.val.col(i));
			image_per_thread[thread_id][idx_pixel] += value;
		}
	}
#endif
}

void ParticleTracerADps::handleMediumInteractionAD(const Scene &scene, const Medium *ptr_med, const RayAD &ray, RndSampler *sampler,
												   const SpectrumAD &weight, int max_bounces, int thread_id) const
{
	Vector2AD pix_uv;
	VectorAD wi = -ray.dir, wo;
	Matrix2x4AD pix_uvs;
	Vector4AD sensor_vals = scene.sampleAttenuatedSensorDirectAD(ray.org, ptr_med, sampler, max_bounces, pix_uvs, wo);
	if (!sensor_vals.val.isZero())
	{
		const PhaseFunction *ptr_phase = scene.phase_list[ptr_med->phase_id];
		for (int i = 0; i < 4; i++)
		{
			if (sensor_vals(i) > Epsilon)
			{
				SpectrumAD value = sensor_vals(i) * ptr_phase->evalAD(wi, wo) * weight;
				// !
				if (value.val.isNaN().maxCoeff())
					continue;
				int idx_pixel = scene.camera.getPixelIndex(pix_uvs.val.col(i));
				image_per_thread[thread_id][idx_pixel] += value;
			}
		}
	}
}