import os
import logging

import torch
import numpy as np
import math

# from . import kpn
# from . import unet
# from . import denoiserutils

# from render import util
# from render import optixutils as ou

###############################################################################
# Single frame denoiser
###############################################################################

# class DLDenoiser(torch.nn.Module):
# 	def __init__(self, influence=1.0):
# 		super(DLDenoiser, self).__init__() 
# 
# 		print("---> Denoiser constructor")
# 
# 		#####################################################################################
# 		encoder_features    = [[32, 32], [32], [32], [32], [32]]
# 		bottleneck_features = [32]
# 		decoder_features    = [[64, 64], [64, 64], [64, 64], [64, 32], [32, 32]]
# 		
# 		self.denoiser = unet.UNet(9, 3, encoder_features, bottleneck_features, decoder_features)
# 		#####################################################################################
# 
# 		self.kpn_kernel = 5
# 		self.kpn_layers = 3
# 		self.kpn        = kpn.RecursiveHierarchicalKernelPredictor([f[-1] for f in list(reversed(decoder_features))[:self.kpn_layers]], self.kpn_kernel, features32=True)
# 		self.set_influence(influence)
# 
# 	def set_influence(self, factor):
# 		self.influence = max(0.0, min(factor, 1.0))
# 
# 	def forward(self, input):
# 		if self.influence <= 0.0:
# 			return input[..., 0:3]
# 		# Kernel prediction
# 		input_nchw = input[..., 0:9].permute(0, 3, 1, 2) # remove gb_depth because it confuses the denoiser
# 		color = input_nchw[:, 0:3, ...]
# 		denoised = self.denoiser(input_nchw)
# 		filtered = self.kpn(color, None, denoised.decoder_mips[:self.kpn_layers])
# 		return torch.lerp(input[..., 0:3], filtered.color.permute(0, 2, 3, 1), self.influence)

class GaussianDenoiser(torch.nn.Module):
	"""Blur the first three channels of an image with a normalized Gaussian.

	The blur strength is driven by ``influence``: the standard deviation is
	``2 * influence`` pixels (floored at 0.0001) and the square kernel has
	``N = 2 * ceil(2.5 * sigma) + 1`` taps per axis, which is always odd.
	"""

	def __init__(self, influence=1.0):
		super(GaussianDenoiser, self).__init__()
		self.set_influence(influence)

	def set_influence(self, factor):
		"""Recompute ``sigma``, ``variance`` and the kernel size ``N`` from ``factor``."""
		self.sigma = max(factor * 2, 0.0001)
		self.variance = self.sigma**2.
		self.N = 2 * math.ceil(self.sigma * 2.5) + 1

	def forward(self, input):
		"""Return the Gaussian-filtered first three channels of ``input``.

		Args:
			input: float32 tensor laid out as (height, width, channels) with
				at least three channels; only channels 0..2 are filtered.

		Returns:
			A (height, width, 3) tensor on the CPU. Borders are zero-padded,
			so pixels near the edge are darkened by the missing taps.
		"""
		# Prefer the GPU, but keep working on machines without CUDA.
		device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		image = input.unsqueeze(0).to(device)

		# Integer tap offsets centred on zero. Note that ``-self.N // 2`` would
		# parse as ``(-self.N) // 2`` and yield an off-centre grid for odd N,
		# so the half width is computed first and then negated.
		half = self.N // 2
		taps = torch.arange(-half, half + 1, dtype=torch.float32, device=device)
		dist_sqr = taps[:, None] ** 2 + taps[None, :] ** 2

		# Any constant prefactor cancels in the normalization below. The centre
		# tap is exp(0) == 1, so the sum can never underflow to zero.
		kernel = torch.exp(-dist_sqr / (2 * self.variance))
		kernel = kernel / torch.sum(kernel)

		# Depthwise convolution: the same kernel filters each channel independently.
		weights = kernel[None, None].repeat(3, 1, 1, 1)

		image_nchw = image[..., :3].permute(0, 3, 1, 2)
		blurred = torch.nn.functional.conv2d(image_nchw, weights, padding=half, groups=3)
		# Drop only the batch axis added above; a bare squeeze() would also
		# collapse a height or width of one.
		return blurred.permute(0, 2, 3, 1).squeeze(0).to('cpu')

# class BilateralDenoiser(torch.nn.Module):
# 	def __init__(self, influence=1.0):
# 		super(BilateralDenoiser, self).__init__()
# 		self.set_influence(influence)
# 
# 	def set_influence(self, factor):
# 		self.sigma = max(factor * 2, 0.0001)
# 		self.variance = self.sigma**2.
# 		self.N = 2 * math.ceil(self.sigma * 2.5) + 1
# 
# 	def svgf(self, col, nrm, zdz, kd):
# 		eps    = 0.0001
# 		accum_col = torch.zeros_like(col)
# 		accum_w = torch.zeros_like(col[..., 0:1])
# 		for y in range(-self.N, self.N+1):
# 			for x in range(-self.N, self.N+1):
# 				dist_sqr = (x**2 + y**2)
# 				dist = np.sqrt(dist_sqr)
# 				w_xy = np.exp(-dist_sqr / (2 * self.variance))
# 
# 				with torch.no_grad():
# 					nrm_tap = torch.roll(nrm, (-y, -x), (1, 2))
# 					w_normal = torch.pow(torch.clamp(util.dot(nrm_tap, nrm), min=eps, max=1.0), 128.0)           # From SVGF
# 					zdz_tap = torch.roll(zdz, (-y, -x), (1, 2))
# 					w_depth = torch.exp(-(torch.abs(zdz_tap[..., 0:1] - zdz[..., 0:1]) / torch.clamp(zdz[..., 1:2] * dist, min=eps)) ) # From SVGF	
# 
# 					w = w_xy * w_normal * w_depth
# 
# 				col_tap = torch.roll(col, (-y, -x), (1, 2))
# 				accum_col += col_tap * w
# 				accum_w += w
# 		return accum_col / torch.clamp(accum_w, min=eps)
# 
# 	def forward(self, input):
# 		col    = input[..., 0:3]
# 		nrm    = util.safe_normalize(input[..., 3:6]) # Bent normals can produce < 1 length normals here.
# 		kd     = input[..., 6:9]
# 		zdz    = input[..., 9:11]
# 		return ou.svgf(col, nrm, zdz, kd, self.sigma)
# 		#return self.svgf(col, nrm, zdz, kd)
# 

#----------------------------------------------------------------------------
# sRGB color transforms
#----------------------------------------------------------------------------

def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise linear-to-sRGB transfer curve element-wise."""
    cutoff = 0.0031308
    # Both branches of torch.where are evaluated for every element, so the
    # input is clamped before pow to keep negative values (which would give
    # NaN under a fractional exponent) out of the branch that is discarded.
    gamma = torch.pow(torch.clamp(f, cutoff), 1.0/2.4)*1.055 - 0.055
    return torch.where(f <= cutoff, f * 12.92, gamma)

def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Convert linear RGB(A) to sRGB.

    Args:
        f: tensor whose last axis holds 3 (RGB) or 4 (RGBA) channels. Any
           number of leading axes is accepted.

    Returns:
        A tensor of the same shape. The colour channels are transformed; a
        fourth channel, when present, is passed through unchanged.
    """
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    if f.shape[-1] == 4:
        out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _rgb_to_srgb(f)
    # Compare whole shapes instead of indexing axes 0..2, which raised
    # IndexError for inputs with fewer than three axes.
    assert out.shape == f.shape
    return out

def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB-to-linear transfer curve element-wise."""
    cutoff = 0.04045
    # Both branches of torch.where are evaluated for every element; clamping
    # keeps the pow base positive in the branch that is discarded.
    gamma = torch.pow((torch.clamp(f, cutoff) + 0.055) / 1.055, 2.4)
    return torch.where(f <= cutoff, f / 12.92, gamma)

def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Convert sRGB(A) to linear RGB.

    Args:
        f: tensor whose last axis holds 3 (RGB) or 4 (RGBA) channels. Any
           number of leading axes is accepted.

    Returns:
        A tensor of the same shape. The colour channels are transformed; a
        fourth channel, when present, is passed through unchanged.
    """
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    if f.shape[-1] == 4:
        out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _srgb_to_rgb(f)
    # Compare whole shapes instead of indexing axes 0..2, which raised
    # IndexError for inputs with fewer than three axes.
    assert out.shape == f.shape
    return out

def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Scale ``x`` along its last axis by the reciprocal of its clamped length.

    The divisor comes from ``length(x, eps)``, which floors the squared norm
    at ``eps`` so an all-zero vector divides by a small number, not by zero.
    """
    norm = length(x, eps)
    return x / norm

def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Return the Euclidean norm of ``x`` over its last axis, keeping that axis.

    The squared norm is floored at ``eps`` before the square root is taken.
    """
    squared = dot(x, x)
    # The derivative of sqrt blows up at zero, so keep the argument strictly
    # positive to avoid non-finite gradients.
    floored = torch.clamp(squared, min=eps)
    return floored.sqrt()

def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the dot product of ``x`` and ``y`` over the last axis.

    The reduced axis is kept with size one so the result broadcasts against
    the inputs.
    """
    products = x * y
    return products.sum(dim=-1, keepdim=True)

class OIDNDenoiser(torch.nn.Module):
	"""Denoise an image with a U-Net loaded from an OIDN ``.tza`` weight file.

	The network comes from ``pyTensorRay.denoiser.oidn.model`` and is built
	with 3 input channels (colour only) or 9 (colour, albedo, normals) when
	``use_guides`` is set. ``influence`` in [0, 1] blends between the
	untouched input colour (0) and the network output (1).
	"""

	def __init__(self, influence=1.0, use_guides=True, backprop_guides=False, weights_path=None):
		"""Build the network and load its weights.

		Args:
			influence: blend factor, clamped to [0, 1].
			use_guides: feed albedo and normal guide channels to the network.
			backprop_guides: let gradients flow into the guide channels.
			weights_path: optional explicit path to a ``.tza`` file. When
				omitted, the file is looked up next to the ``oidn.model``
				module: ``rt_ldr_alb_nrm.tza`` with guides, else ``rt_ldr.tza``.
		"""
		super(OIDNDenoiser, self).__init__()
		self.channels = 9 if use_guides else 3
		self.backprop_guides = backprop_guides
		from pyTensorRay.denoiser.oidn import model
		self.unet = model.UNet(self.channels, 3)
		if weights_path is None:
			# Resolve relative to the installed package rather than a
			# machine-specific absolute path, and match the weight file to
			# the number of input channels the network was built with.
			# NOTE(review): assumes both .tza files ship alongside model.py — confirm.
			weights_name = 'rt_ldr_alb_nrm.tza' if use_guides else 'rt_ldr.tza'
			weights_path = os.path.join(os.path.dirname(os.path.abspath(model.__file__)), weights_name)
		self.load_weights(weights_path)
		self.set_influence(influence)

	def set_influence(self, factor):
		"""Store ``factor`` clamped to the range [0, 1]."""
		self.influence = max(0.0, min(factor, 1.0))

	def map_forward(self, x):
		"""Compress colour with ``log(x + 1)`` and then encode it as sRGB."""
		return rgb_to_srgb(torch.log(x + 1.0))

	def map_inverse(self, x):
		"""Undo ``map_forward``: decode sRGB, then apply ``exp(x) - 1``."""
		return torch.exp(srgb_to_rgb(x)) - 1.0

	def forward(self, input):
		"""Return the denoised colour blended with the input by ``influence``.

		Args:
			input: tensor laid out as (height, width, channels). Channels
				0..2 are colour; with guides enabled, channels 3..5 are read
				as normals and 6..8 as albedo.

		Returns:
			A (height, width, 3) tensor on the CPU.
		"""
		image = input.unsqueeze(0).to('cpu')
		color = image[..., 0:3]
		if self.influence <= 0.0:
			# Drop the batch axis here as well so both return paths agree on shape.
			return color.squeeze(0)

		# map and sanitize inputs
		rgb = torch.clamp(self.map_forward(color), 0, 1)
		if self.channels == 9:
			albedo = torch.clamp(image[..., 6:9], 0, 1)
			# Remap unit normals from [-1, 1] to [0, 1].
			normals = torch.clamp(safe_normalize(image[..., 3:6]) * 0.5 + 0.5, 0, 1)

			# detach guide buffers?
			if not self.backprop_guides:
				albedo = albedo.detach()
				normals = normals.detach()

			oidn_input = torch.cat((rgb, albedo, normals), dim=-1)
		else:
			oidn_input = rgb

		# run network (channels-last -> NCHW and back)
		oidn_output = self.unet(oidn_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

		# apply inverse input mapping and interpolate
		blended = torch.lerp(color, self.map_inverse(oidn_output), self.influence)
		# Remove only the batch axis; a bare squeeze() would also collapse a
		# height or width of one.
		return blended.squeeze(0).to('cpu')

	def load_weights(self, tza_path: str):
		"""Replace every network parameter with the same-named entry of a ``.tza`` file.

		NOTE(review): the loaded arrays are assigned without a shape check, so a
		weight file that does not match the network layout is not detected here.
		"""
		from pyTensorRay.denoiser.oidn import tza
		reader = tza.Reader(tza_path)
		with torch.no_grad():
			for name, param in self.unet.named_parameters():
				# reader[name][0] is presumably the raw weight array — confirm against tza.Reader.
				param.data = torch.from_numpy(np.array(reader[name][0]))
