import torch
import torch.nn as nn

from . import denoiserutils

#################################################################################
# Recursive and hierarchical kernel predictor.
#################################################################################

def roundup32(x):
	"""Round x up to the nearest multiple of 32.

	Values that are already a multiple of 32 are returned unchanged.
	"""
	alignment = 32
	# Adding (alignment - 1) before the floor division turns it into a ceiling division.
	blocks = (x + alignment - 1) // alignment
	return blocks * alignment

class RecursiveHierarchicalKernelPredictor(nn.Module):
	"""Kernel-predicting filter evaluated over a mip pyramid, coarse to fine.

	For every mip level a 1x1 convolution turns that level's weight features
	into per-pixel filter weights plus one extra "alpha" channel. The shared
	`WeightedFilter` layer applies the weights to an average-pooled copy of
	the input color, and the result is merged with the upsampled output of
	the next coarser level. At the finest level (index 0) a second set of
	weights can additionally filter a previous-frame color.

	All sub-modules are moved to CUDA at construction time.
	"""

	def __init__(self, input_features, kernel_size, normalize=True, splat=False, features32=False):
		"""Build the shared filter layer and one 1x1 weight convolution per mip level.

		Args:
			input_features: sequence with the number of input channels of the
				weight features for each mip level; index 0 is the finest level.
			kernel_size: side length of the square filter kernel; each kernel
				has kernel_size * kernel_size weights.
			normalize: if True, `forward` applies a softmax over the predicted
				weights and blends mip levels with the predicted alpha; if
				False, levels are summed instead.
			splat: forwarded unchanged to `WeightedFilter`.
			features32: if True, the output channel count of every weight
				convolution is rounded up to a multiple of 32 via `roundup32`.
				`forward` never reads the extra channels.
		"""
		super(RecursiveHierarchicalKernelPredictor, self).__init__()
		self.kernel_size    = kernel_size
		self.input_features = input_features
		self.normalize      = normalize
		self.splat          = splat

		# Create common KPN layer, used for all miplevels
		# (imported here rather than at module level, as in the original code).
		from torch_utils import WeightedFilter
		self.kpn_layer = WeightedFilter(3, kernel_size, bias=False, splat=splat).cuda()

		taps = self.kernel_size*self.kernel_size

		def out_channels(count):
			# Optionally pad the channel count up to a multiple of 32.
			return roundup32(count) if features32 else count

		# Create weight layers.
		# Level 0 predicts two kernels (current + previous frame) plus alpha;
		# every other level predicts a single kernel plus alpha.
		layers = [nn.Conv2d(self.input_features[0], out_channels(2*taps + 1), 1, padding=0)]
		layers += [nn.Conv2d(features, out_channels(taps + 1), 1, padding=0) for features in self.input_features[1:]]
		self.wa_layers = nn.ModuleList(layers).cuda()

	def forward(self, color, color_prev, weight_features):
		"""Filter `color` hierarchically and return the combined result.

		Args:
			color: tensor to filter. Level i works on `color` average-pooled
				with kernel and stride 2**i.
			color_prev: optional tensor; when not None it is filtered at
				level 0 with the second predicted kernel and added to the output.
			weight_features: sequence with one tensor per mip level (same
				length as `input_features`); entry i is fed to weight layer i.
				Presumably entry i has the spatial size of the level-i color
				-- TODO confirm against the caller.

		Returns:
			The result of `denoiserutils.object_from_dict` on a dict whose
			'color' entry is the filtered output at the finest level.
		"""
		assert len(weight_features) == len(self.input_features), \
			"expected %d weight feature levels, got %d" % (len(self.input_features), len(weight_features))

		ksize = self.kernel_size*self.kernel_size
		coarsest = len(weight_features) - 1
		has_prev = color_prev is not None

		# Walk the pyramid from the coarsest level down to level 0.
		for i in range(coarsest, -1, -1):
			use_prev = has_prev and i == 0

			# Create scaled mip version of input
			scaled_color = nn.functional.avg_pool2d(color, kernel_size=1<<i, stride=1<<i)

			# Run mixing layer and get weights + alphas.
			# With a previous frame, level 0 uses two kernels' worth of
			# weights; alpha is always the channel right after the weights.
			wa = self.wa_layers[i](weight_features[i])
			num_weights = 2*ksize if use_prev else ksize
			weights = wa[:, 0:num_weights, ...]
			alpha = wa[:, num_weights:num_weights+1, ...]

			# Get normalized weights through softmax (jointly over both
			# kernels when the previous frame is used).
			if self.normalize:
				weights = nn.functional.softmax(weights, dim=1)

			# Evaluate the kernel predicted filter
			i_f = self.kpn_layer(scaled_color.contiguous(), weights[:, 0:ksize, ...].contiguous())

			# Mix with result of previous (coarser) miplevel
			if i == coarsest:
				i_c = i_f # Coarsest level
			else:
				U = nn.functional.interpolate(i_c, mode='bilinear', size=i_f.shape[-2:], align_corners=False)
				if self.normalize:
					# Simplified replacement for Eq. 5 in
					# http://zurich.disneyresearch.com/~fabricer/publications/vogels-2018-kpal.pdf
					# which reads i_c = i_f + (U - UD) * sigmoid(alpha), with UD the
					# down- then up-sampled i_f. Here a plain linear blend is used.
					# NOTE(review): alpha is used raw, without a sigmoid -- confirm intended.
					i_c = i_f * (1.0 - alpha) + U * alpha
				else:
					i_c = U + i_f

			# For base miplevel mix with previous frame
			if use_prev:
				i_w = self.kpn_layer(color_prev.contiguous(), weights[:, ksize:2*ksize, ...].contiguous())
				i_c = i_c + i_w

		return denoiserutils.object_from_dict({
			'color'    : i_c,
		})