#
# Copyright 2019 NVIDIA Corporation. All rights reserved.
#

import os
import re
import glob
import math
import time

import numpy as np
import torch

###############################################################################
# Some utility functions to make pytorch and numpy behave the same
###############################################################################

def _pow(x, y):
	if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
		return torch.pow(x, y)
	else:
		return np.power(x, y)

def _log(x):
	if isinstance(x, torch.Tensor):
		return torch.log(x)
	else:
		return np.log(x)

def _exp(x):
	if isinstance(x, torch.Tensor):
		return torch.exp(x)
	else:
		return np.exp(x)

def _clamp(x, y, z):
	if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor) or isinstance(z, torch.Tensor):
		return torch.clamp(x, y, z)
	else:
		return np.clip(x, y, z)

def _zeros_like(x):
	if isinstance(f, torch.Tensor):
		return torch.zeros_like(x)
	else:
		return np.zeros_like(x)

def _mean(x, dim=None, keepdim=False):
	if isinstance(f, torch.Tensor):
		return x.mean(dim=dim, keepdim=keepdim)
	else:
		return x.mean(axis=dim, keepdims=keepdim)

def space_to_depth(x):
	assert x.shape[-1] % 2 == 0 and x.shape[-2] % 2 == 0
	return torch.cat((x[..., 0::2, 0::2], x[..., 1::2, 0::2], x[..., 0::2, 1::2], x[..., 1::2, 1::2]), dim=-3)

def depth_to_space(x):
	assert x.shape[-3] % 4 == 0
	c = x.shape[-3] // 4
	res = torch.zeros_like(x).view(*[s for s in x.shape[:-3]], c, x.shape[-2] * 2, x.shape[-1] * 2)
	res[..., 0::2, 0::2] = x[..., 0*c:1*c, :, :].clone()
	res[..., 1::2, 0::2] = x[..., 1*c:2*c, :, :].clone()
	res[..., 0::2, 1::2] = x[..., 2*c:3*c, :, :].clone()
	res[..., 1::2, 1::2] = x[..., 3*c:4*c, :, :].clone()
	return res

###############################################################################
# Create a object with members from a dictionary
###############################################################################

class DictObject:
	def __init__(self, _dict):
		self.__dict__.update(**_dict)

def object_from_dict(_dict):
	return DictObject(_dict)

def object_merge_dict(obj, _dict):
	o_dict = obj.__dict__.copy()
	o_dict.update(_dict)
	return DictObject(o_dict)

###############################################################################
# HWC <-> CHW format conversion
###############################################################################

def HWCtoCHW(x):
	return np.moveaxis(x, -1, 0)

def CHWtoHWC(x):
	return np.moveaxis(x, 0, -1)

###############################################################################
# SMAPE Loss
###############################################################################

def SMAPE(d, r):
	denom = torch.abs(d) + torch.abs(r) + 0.01 
	return torch.mean(torch.abs(d-r) / denom)

def tSMAPE(d, r, dd, rr):
	denom = torch.abs(d) + torch.abs(r) + 0.01 
	return torch.mean(torch.abs(dd-rr) / denom)

###############################################################################
# Tonemapping
###############################################################################

def tonemap_nop(f):
	return f

def tonemap_scale(f, c):
	return f * c

def tonemap_gamma(f):
	return _pow(f, 1.0/2.2)

def tonemap_reinhard(f):
	fc = _clamp(f, 0.0, 65536.0)
	return fc / (fc + 1.0)

def tonemap_log(f):
	fc = _clamp(f, 0.0, 65536.0)
	return _log(fc + 1.0)

def tonemap_log_gamma(f):
	fc = _clamp(f, 0.0, 65536.0)
	return _pow(_log(fc + 1.0), 1.0/2.2)

#Transfer function taken from https://arxiv.org/pdf/1712.02327.pdf
def tonemap_srgb(f):
	a = 0.055
	if isinstance(f, torch.Tensor):
		return torch.where(f > 0.0031308, _pow(torch.clamp(f, 0.0031308), 1.0/2.4)*(1 + a) - torch.ones_like(f)*a, 12.92*f)
	else:
		return np.where(f > 0.0031308, _pow(f, 1.0/2.4)*(1 + a) - a, 12.92*f)


###############################################################################
# Jetmap
###############################################################################

def jetmap(v):
	bias = torch.cuda.FloatTensor([3.0, 2.0, 1.0]).view(1,3,1,1)
	vv = torch.clamp(v[:, None, :, :].repeat(1, 3, 1, 1) if len(v.shape) == 3 else v.repeat(1, 3, 1, 1), 0.0, 1.0)
	return torch.clamp(1.5 - torch.abs(4.0 * vv - bias), 0.0, 1.0)

###############################################################################
# Create a folder if it doesn't exist
###############################################################################

def mkdir(x):
	os.makedirs(x, exist_ok=True)

###############################################################################
# Reprojection / warping helper function
###############################################################################

def warp(col, motionvecs, resolution):
	# Collapse to 4D tensor if higher dimension
	cView = col.view(col.shape[0], -1, col.shape[-2], col.shape[-1])
	# Create grid lookup coordinates that align with pixel centers for align_corners=False
	grid_y, grid_x = torch.meshgrid([
		torch.linspace(-1.0 + 1.0 / col.shape[-2], 1.0 - 1.0 / col.shape[-2], col.shape[-2], device='cuda'), 
		torch.linspace(-1.0 + 1.0 / col.shape[-1], 1.0 - 1.0 / col.shape[-1], col.shape[-1], device='cuda')
	])
	# Offset by normalized motion vectors
	mvs = torch.stack((
		motionvecs[:, 0, ...] * (2.0 * resolution[1] / col.shape[-1]) + grid_x[None, :, :],
		motionvecs[:, 1, ...] * (2.0 * resolution[0] / col.shape[-2]) + grid_y[None, :, :]
	), dim=3)
	return torch.nn.functional.grid_sample(cView, mvs, padding_mode='border', align_corners=False).view(*col.shape)

###############################################################################
# Initialization for D2S/S2D networks
###############################################################################

def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
	"ICNR init of `x`, with `scale` and `init` function."
	ni, nf, h, w = x.shape
	if ni % (scale**2) == 0:
		ni2 = ni // (scale**2)
		z = torch.zeros_like(x[0:ni2, ...])
		k = init(z).transpose(0, 1)
		k = k.contiguous().view(ni2, nf, -1)
		k = k.repeat(1, 1, scale**2)
		k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
		x = k.clone()
	else:
		init(x)