#include "volpathADps.h"
#include "scene.h"
#include "sampler.h"
#include "rayAD.h"
#include "intersectionAD.h"
#include "stats.h"
#include "math_func.h"
#include <iomanip>
#include "omp.h"
#include <fstream>

// Debugging aid: writes the source text of the expression `x`, a colon, and the
// value of `x` to std::cerr, followed by std::endl. The do { ... } while (0)
// wrapper makes the expansion a single statement, so `DEBUG(v);` can be used
// safely inside unbraced if/else bodies.
// Note: the macro streams std::setprecision(15) into std::cerr and does not
// restore the previous precision afterwards.
#define DEBUG(x)                                                            \
	do                                                                      \
	{                                                                       \
		std::cerr << std::setprecision(15) << #x << ": " << x << std::endl; \
	} while (0)

using namespace std;

/******    Interior Term    ******/
/**
 * Renders the interior term (pixel values and their derivatives) into
 * `rendered_image`.
 *
 * The pixels of the crop window (when `scene.camera.rect` is valid) or of the
 * full frame are grouped into blocks of `size_block` pixels, and the blocks are
 * distributed over OpenMP threads. For each pixel, `options.num_samples` calls
 * to pixelColorAD() are accumulated and the sum is divided by
 * `options.num_samples`. A sample is reported on std::cerr and left out of the
 * sum when its value has a non-finite or negative channel, or when its largest
 * absolute derivative is non-finite or not below `options.grad_threshold`;
 * the divisor is not adjusted for skipped samples.
 *
 * Output layout (3 floats per pixel, `num_pixels` pixels per plane):
 *   - plane 0 holds the value,
 *   - plane `ch` (1 <= ch <= nder) holds `grad(ch - 1)`.
 *
 * @param scene           Scene to render; only `scene.camera` is read directly here.
 * @param options         Render options (seed, sample counts, gradient threshold, quiet flag).
 * @param rendered_image  Destination buffer; it is indexed up to
 *                        `(nder + 1) * num_pixels * 3` floats.
 */
void VolPathTracerADps::renderInterior(const Scene &scene, const RenderOptions &options, ptr<float> rendered_image) const {
    const auto &camera = scene.camera;
    const bool cropped = camera.rect.isValid();
    const int num_pixels = cropped ? camera.rect.crop_width * camera.rect.crop_height
                                   : camera.width * camera.height;
    const int size_block = 4;
    const int nworker = omp_get_num_procs();

    // Pixel sampling
    // Exact integer ceiling division. Going through a floating-point division
    // and std::ceil can round the block count down when the pixel count is not
    // exactly representable in Float, which would leave trailing pixels unrendered.
    const int num_block = (num_pixels + size_block - 1) / size_block;
    int finished_block = 0;
    #pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
    for (int index_block = 0; index_block < num_block; index_block++) {
        int block_start = index_block*size_block;
        // The last block may be partial.
        int block_end = std::min((index_block+1)*size_block, num_pixels);

        for (int idx_pixel = block_start; idx_pixel < block_end; idx_pixel++) {
            // Linear index -> image coordinates (shifted by the crop offset when cropping).
            int ix = cropped ? camera.rect.offset_x + idx_pixel % camera.rect.crop_width
                             : idx_pixel % camera.width;
            int iy = cropped ? camera.rect.offset_y + idx_pixel / camera.rect.crop_width
                             : idx_pixel / camera.width;
            // One sampler per pixel, constructed from the global seed and the pixel index.
            RndSampler sampler(options.seed, idx_pixel);

            SpectrumAD pixel_val;
            for (int idx_sample = 0; idx_sample < options.num_samples; idx_sample++) {
                // Random sub-pixel offset, or the pixel centre when
                // num_samples_primary_edge is negative.
                const Array2 rnd = options.num_samples_primary_edge >= 0 ? Array2(sampler.next1D(), sampler.next1D()) : Array2(0.5f, 0.5f);
                SpectrumAD tmp = pixelColorAD(scene, options, &sampler, static_cast<Float>(ix + rnd.x()), static_cast<Float>(iy + rnd.y()));
                bool val_valid = std::isfinite(tmp.val[0]) && std::isfinite(tmp.val[1]) && std::isfinite(tmp.val[2]) && tmp.val.minCoeff() >= 0.0f;
                Float tmp_val = tmp.der.abs().maxCoeff();
                bool der_valid = std::isfinite(tmp_val) && tmp_val < options.grad_threshold;
                if ( val_valid && der_valid ) {
                    pixel_val += tmp;
                } else {
                    // Serialize warnings so output from different threads does not interleave.
                    omp_set_lock(&messageLock);
                    if (!val_valid)
                        std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Invalid path contribution: [" << tmp.val << "]" << std::endl;
                    if (!der_valid)
                        std::cerr << std::scientific << std::setprecision(2) << "\n[WARN] Rejecting large gradient: [" << tmp.der << "]" << std::endl;
                    omp_unset_lock(&messageLock);
                }
            }
            pixel_val /= options.num_samples;

            // Plane 0: pixel value.
            rendered_image[idx_pixel*3    ] = static_cast<float>(pixel_val.val(0));
            rendered_image[idx_pixel*3 + 1] = static_cast<float>(pixel_val.val(1));
            rendered_image[idx_pixel*3 + 2] = static_cast<float>(pixel_val.val(2));
            // Planes 1..nder: one derivative image per differentiated parameter.
            for ( int ch = 1; ch <= nder; ++ch ) {
                int offset = (ch*num_pixels + idx_pixel)*3;
                rendered_image[offset    ] = static_cast<float>((pixel_val.grad(ch - 1))(0));
                rendered_image[offset + 1] = static_cast<float>((pixel_val.grad(ch - 1))(1));
                rendered_image[offset + 2] = static_cast<float>((pixel_val.grad(ch - 1))(2));
            }
        }

        if ( !options.quiet ) {
            // finished_block is shared between threads; it is only modified while
            // holding messageLock.
            omp_set_lock(&messageLock);
            progressIndicator(Float(++finished_block)/num_block);
            omp_unset_lock(&messageLock);
        }
    }

    if ( !options.quiet ) std::cout << std::endl;
}

/**
 * Estimates the radiance (value and derivatives) carried by one camera path
 * that starts along `inRay`, handling both medium and surface interactions.
 *
 * The path is traced with plain (non-AD) quantities (`ray`, `its`), while the
 * contribution is accumulated with AD quantities (`throughputAD`, `ret`). At
 * every vertex two estimates of the emitter contribution are added:
 *   - emitter sampling via scene.sampleAttenuatedEmitterDirectAD(), and
 *   - phase-function / BSDF sampling via scene.rayIntersectAndLookForEmitterADps(),
 * each weighted by a ratio of the form a^2 / (a^2 + b^2) of the two pdfs.
 *
 * @param scene        Scene being rendered.
 * @param sampler      Random number source; the order of draws below matters.
 * @param inRay        Primary ray; its origin seeds the first event record and
 *                     `inRay.toRay()` is used for tracing.
 * @param pixel_x      Pixel column, forwarded to camera.evalFilterAD().
 * @param pixel_y      Pixel row, forwarded to camera.evalFilterAD().
 * @param max_depth    Upper bound of the depth loop (inclusive).
 * @param med_default  Medium used at the start of the path; nullptr means the
 *                     medium-sampling step is skipped until a medium is entered.
 * @return             Accumulated contribution of the path.
 */
SpectrumAD VolPathTracerADps::LiAD(const Scene &scene, RndSampler *sampler, const RayAD &inRay, int pixel_x, int pixel_y, int max_depth, const Medium *med_default) const
{
	Ray ray = inRay.toRay();
	Intersection its;
	IntersectionAD itsAD;
	SpectrumAD ret;
	const Medium* ptr_med = med_default;
	// Only passed to sampleDistance() below; it is not read anywhere else in this function.
	Spectrum throughput = Spectrum::Ones();
	SpectrumAD throughputAD(Spectrum::Ones());
	// Updated after BSDF sampling but not read in this function.
	Float eta = 1.0f;
	// True until the first medium vertex or non-null surface vertex has been processed;
	// also gates directly-hit emission.
	bool incEmission = true;
	scene.rayIntersect(ray, false, its);
	// Previous path vertex; initially the ray origin.
	EventRecordAD preEvt(inRay.org, VectorAD(), ptr_med);

	for(int depth = 0; depth <= max_depth; depth++)
	{
		bool isFirst = incEmission;
		int max_interactions = max_depth - depth - 1;
		// sample next event: a medium interaction before its.t, or the surface hit
		// (sampleDistance writes the sampled position into ray.org)
		bool inside_med = ptr_med != nullptr &&
				  ptr_med->sampleDistance(ray, its.t, sampler->next2D(), sampler, ray.org, throughput);

		// handle hybrid: when isHybrid() is set, the path is dropped if its first
		// event is a medium interaction or a non-null, non-transmissive surface
		if (isHybrid())
		{
			if (isFirst)
			{
				if (inside_med)
					break;
				else
				{
					// NOTE(review): its.ptr_bsdf is dereferenced here before the
					// its.isValid() check further below — confirm it is non-null when
					// the ray hits nothing.
					if (!its.ptr_bsdf->isNull())
						if (!its.ptr_bsdf->isTransmissive())
							break;
				}
			}
		}

		// handle event
		if (inside_med)
		{
			if (depth >= max_depth) break;
			// if (throughput.isZero(Epsilon))
			// 	break;
			const PhaseFunction* ptr_phase = scene.phase_list[ptr_med->phase_id];

			// AD position of the sampled medium point, with the factor J returned by getPoint()
			VectorAD x;
			FloatAD J;
			if(!ptr_med->getPoint(ray.org, x, J))
				break;
			VectorAD pre_x = preEvt.x();
			VectorAD dir = x - pre_x;
			FloatAD dist = dir.norm();
			dir /= dist;

			// calculate throughput: the AD numerator is divided by the non-AD pdf
			// built from the primal values (sigT * G.val)
			SpectrumAD f = isFirst ? scene.camera.evalFilterAD(pixel_x, pixel_y, x) : preEvt.f(scene, dir);
			FloatAD G = isFirst ? FloatAD(1.) : 1.0 / dist.square();
			FloatAD Tratio = scene.evalTransmittanceAD(RayAD(pre_x, dir), preEvt.onSurface, preEvt.getMedium(dir.val), dist, sampler, depth, true);
			SpectrumAD sigS = ptr_med->sigSAD(x);
			Float sigT = ptr_med->sigT(x.val);
			Float pdf = sigT * G.val;
			throughputAD *= f * Tratio * sigS * G * J / pdf;

			// direct illumination by emitter sampling
			Vector wo;
			VectorAD woAD;
			Float pdf_nee;
			Float G1;
			auto value = scene.sampleAttenuatedEmitterDirectAD(x, sampler->next2D(), sampler, ptr_med, max_interactions, woAD, pdf_nee, &G1);
			if (!value.isZero(Epsilon))
			{
				FloatAD phase_val = ptr_phase->evalAD(-dir, woAD);
				if (!phase_val.isZero(Epsilon))
				{
					Float phase_pdf = ptr_phase->pdf(-dir.val, woAD.val) * G1;
					// A negative pdf_nee gets full weight; otherwise weight by pdf_nee^2 / (pdf_nee^2 + phase_pdf^2)
					auto mis_weight = pdf_nee < 0.0 ? 1.0 : square(pdf_nee) / (square(pdf_nee) + square(phase_pdf));
					// mis_weight = 1;
					ret += throughputAD * value * mis_weight * phase_val;
				}
			}

			// indirect illumination: sample the phase function with the non-AD direction
			if (ptr_phase->sample(-ray.dir, sampler->next2D(), wo) == 0)
				break;
			// trace a new ray in this direction
			ray.dir = wo;
			SpectrumAD attenuated_radiance = scene.rayIntersectAndLookForEmitterADps(RayAD(x, ray.dir), false, sampler, ptr_med, max_interactions, its, woAD, pdf_nee, G1);
			if (!attenuated_radiance.isZero(Epsilon))
			{
				FloatAD phase_val = ptr_phase->evalAD(-dir, woAD);
				Float phase_pdf = ptr_phase->pdf(-dir.val, ray.dir) * G1;
				// NOTE(review): unlike the emitter-sampling weight above, a negative
				// pdf_nee is not special-cased here — confirm this is intended.
				auto mis_weight = square(phase_pdf) / (square(pdf_nee) + square(phase_pdf));
				ret += throughputAD * attenuated_radiance * mis_weight * phase_val / phase_pdf;
			}
			incEmission = false;
			preEvt = EventRecordAD(x, -dir, ptr_med);
		}
		else
		{
			// surface event: stop when the ray left the scene
			if (!its.isValid())
				break;

			// getPoint: AD version of the hit point, with the factor J returned by getPoint()
			FloatAD J;
			scene.getPoint(its, preEvt.x(), itsAD, J);
			VectorAD x = itsAD.p;
			VectorAD pre_x = preEvt.x();
			VectorAD dir = x - pre_x;
			FloatAD dist = dir.norm();
			dir /= dist;

			// calculate throughput (skipped for null-BSDF surfaces, which are passed through)
			if (!its.ptr_bsdf->isNull())
			{
				SpectrumAD f = isFirst ? scene.camera.evalFilterAD(pixel_x, pixel_y, itsAD) : preEvt.f(scene, dir);
				FloatAD G = isFirst ? FloatAD(1.) : itsAD.geoFrame.n.dot(-dir) / dist.square();
				FloatAD Tratio = scene.evalTransmittanceAD(RayAD(pre_x, dir), preEvt.onSurface, preEvt.getMedium(dir.val), dist, sampler, depth, true);
				Float pdf = G.val;
				throughputAD *= f * Tratio * G * J / pdf;
			}
			// abandon the path if the throughput value became non-finite
			if(!throughputAD.val.allFinite()) break;//!
			// directly-hit emission is only added while incEmission is still set
			if (its.isEmitter() && incEmission)
				ret += throughputAD * itsAD.Le(-dir);

			if (depth >= max_depth)
				break;

			Float pdf_nee;
			Vector wo;
			// direct illumination by emitter sampling (not done on null-BSDF surfaces)
			if (!its.ptr_bsdf->isNull())
			{
				VectorAD woAD;
				Float G1;
				auto value = scene.sampleAttenuatedEmitterDirectAD(itsAD, sampler->next2D(), sampler, ptr_med, max_interactions, woAD, pdf_nee, &G1);
				if (!value.isZero(Epsilon))
				{
					// NOTE(review): woAD is passed to evalBSDF() as-is here, while the
					// BSDF-sampling branch below applies itsAD.toLocal() first —
					// presumably the two scene routines return woAD in different
					// frames; verify.
					auto bsdf_val = itsAD.evalBSDF(woAD);
					Float bsdf_pdf = itsAD.pdfBSDF(woAD.val) * G1;
					auto mis_weight = pdf_nee < 0.0 ? 1.0 : square(pdf_nee) / (square(pdf_nee) + square(bsdf_pdf));
					// mis_weight = 1.0;
					ret += throughputAD * value * bsdf_val * mis_weight;
				}
			}
			// Indirect illumination: sample the BSDF using the non-AD intersection
			Float bsdf_pdf, bsdf_eta;
			auto bsdf_weight = its.sampleBSDF(sampler->next3D(), wo, bsdf_pdf, bsdf_eta);
			if (bsdf_weight.isZero())
				break;
			wo = its.toWorld(wo);
			ray = Ray(its.p, wo);

			// switch the current medium when crossing a medium boundary
			if (its.isMediumTransition())
				ptr_med = its.getTargetMedium(wo);

			if (its.ptr_bsdf->isNull())
				// null surface: just continue the ray; preEvt and incEmission stay unchanged
				scene.rayIntersect(ray, true, its);
			else
			{
				Float G1;
				VectorAD woAD;
				SpectrumAD attenuated_radiance = scene.rayIntersectAndLookForEmitterADps(RayAD(x, ray.dir), true, sampler, ptr_med, max_interactions, its, woAD, pdf_nee, G1);
				eta *= bsdf_eta;
				if (!attenuated_radiance.isZero(Epsilon))
				{
					auto bsdf_val = itsAD.evalBSDF(itsAD.toLocal(woAD));
					bsdf_pdf *= G1;
					// NOTE(review): a negative pdf_nee is not special-cased here,
					// unlike the emitter-sampling weight above — confirm.
					auto mis_weight = square(bsdf_pdf) / (square(pdf_nee) + square(bsdf_pdf));
					// mis_weight = 1.;
					ret += throughputAD * attenuated_radiance * mis_weight * bsdf_val / bsdf_pdf;
				}
				incEmission = false;
				preEvt = EventRecordAD(itsAD);
			}
		}
	}
	return ret;
}

/**
 * Evaluates one sample of the pixel containing (x, y) as the average of two
 * LiAD() estimates, one for each ray of the pair returned by
 * camera.samplePrimaryRayFromFilter().
 *
 * The sampler state is captured after the ray pair has been drawn and is
 * restored before each LiAD() call, so both rays are traced with the same
 * random number sequence.
 *
 * @param scene    Scene to render.
 * @param options  Render options; only `max_bounces` is read here.
 * @param sampler  Random number source (its state is rewound between the two rays).
 * @param x        Horizontal sample position; truncated to the pixel column.
 * @param y        Vertical sample position; truncated to the pixel row.
 * @return         0.5 * (LiAD for the first ray + LiAD for the second ray).
 */
SpectrumAD VolPathTracerADps::pixelColorAD(const Scene &scene, const RenderOptions &options, RndSampler *sampler, Float x, Float y) const {
	int pixel_x = static_cast<int>(x);
	int pixel_y = static_cast<int>(y);
	const auto &camera = scene.camera;

	// Draw the pair of primary rays; this consumes one 2D sample.
	Ray ray_pair[2];
	camera.samplePrimaryRayFromFilter(pixel_x, pixel_y, sampler->next2D(), ray_pair[0], ray_pair[1]);

	// A medium id of -1 means the camera is not inside any medium.
	const auto med_id = camera.getMedID();
	const Medium *init_med = nullptr;
	if (med_id != -1)
		init_med = scene.medium_list[med_id];

	// Both rays start at the camera position.
	VectorAD origin = camera.cpos;

	// Remember the sampler state so each ray replays the same random numbers.
	const uint64_t saved_state = sampler->state;

	SpectrumAD result;
	for (Ray &primary : ray_pair)
	{
		sampler->state = saved_state;
		result += 0.5 * LiAD(scene, sampler, RayAD(origin, primary.dir), pixel_x, pixel_y, options.max_bounces, init_med);
	}
	return result;
}