import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pypsdr.utils.io import imread
import argparse, os
from skimage.transform import resize

def convertEXR2ColorMap(exrfile, pngfile, vmin, vmax, resize_scale, withColorBar = False):
	"""Render the luminance of an EXR image with the CubicL colormap and save it.

	Args:
		exrfile: Path of the input image, read with ``pypsdr.utils.io.imread``.
		pngfile: Path of the output image written by ``plt.savefig`` (dpi=100).
		vmin, vmax: Colormap limits passed to ``plt.imshow``; ``None`` lets
			matplotlib derive them from the data. Numeric strings are accepted
			and converted with ``float``.
		resize_scale: Factor applied to the image height and width before
			plotting (e.g. 0.5 halves each dimension).
		withColorBar: If true, a vertical colorbar axes is added on the right.

	Raises:
		ValueError: If ``vmin``/``vmax`` is a non-numeric string.
	"""
	# Resolve the colormap table next to this script instead of chdir-ing into
	# its directory: os.path.dirname(__file__) is '' when the script is run from
	# its own directory (os.chdir('') raises), and chdir mutates process-wide
	# state that was not restored if loading failed.
	cmap_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "CubicL.txt")
	cubicL = LinearSegmentedColormap.from_list("cubicL", np.loadtxt(cmap_path), N=256)

	def remap(img):
		# Identity hook applied before colormapping. A disabled alternative was a
		# signed log compression: np.sign(img) * np.log1p(np.abs(100 * img)).
		return img

	def rgb2gray(img):
		# Weighted sum of the first three channels; any further channels are
		# ignored. Input with fewer than three channels is passed through as a
		# single 2-D channel instead of failing with an IndexError.
		if img.ndim == 2:
			return img
		if img.shape[2] < 3:
			return img[:, :, 0]
		return 0.2989*img[:, :, 0] + 0.5870*img[:, :, 1] + 0.1140*img[:, :, 2]

	# The CLI may hand these over as strings; make the numeric intent explicit.
	if vmin is not None:
		vmin = float(vmin)
	if vmax is not None:
		vmax = float(vmax)

	img_input = rgb2gray(imread(exrfile))
	# skimage expects an integer output shape; a fractional product (e.g. an odd
	# size times 0.5) may be rejected, so round and keep at least one pixel.
	out_shape = (
		max(1, int(round(img_input.shape[0] * resize_scale))),
		max(1, int(round(img_input.shape[1] * resize_scale))),
	)
	img_input = resize(img_input, out_shape, anti_aliasing=True)
	ratio = img_input.shape[0]/img_input.shape[1]
	if withColorBar:
		# Presumably shortened by 1.1 so the wider aspect leaves room for the
		# colorbar axes placed at x=0.9 below.
		fig = plt.figure(figsize=(5, 5*ratio/1.1))
	else:
		fig = plt.figure(figsize=(5, 5*ratio))

	try:
		im = plt.imshow(remap(img_input), interpolation='bilinear', vmin=vmin, vmax=vmax, cmap=cubicL)
		plt.axis('off')
		if withColorBar:
			plt.subplots_adjust(bottom=0.0, left=-0.05, top=1.0)
			cax = fig.add_axes([0.9, 0.02, 0.04, 0.96])
			plt.colorbar(im, cax=cax)
		else:
			plt.subplots_adjust(bottom=0.0, left=0.0, top=1.0, right=1.0)
		plt.savefig(pngfile, dpi=100)
	finally:
		# Always release the figure, even if plotting/saving fails, so batch
		# conversions over a directory do not accumulate open figures.
		plt.close(fig)

def main(args):
	"""Convert one EXR file, or every ``.exr`` file under a directory tree, to PNG.

	Each PNG is written next to its source with the extension replaced by
	``.png``. ``args`` supplies ``path``, ``vmin``, ``vmax``, ``resize`` and
	``colorbar`` (a colorbar is drawn when ``colorbar > 0``).
	"""
	def convert_one(src):
		dst = os.path.splitext(src)[0] + '.png'
		convertEXR2ColorMap(src, dst, args.vmin, args.vmax, args.resize, args.colorbar > 0)

	# A single file (or anything that is not a directory) is converted directly.
	if not os.path.isdir(args.path):
		convert_one(args.path)
		return

	for folder, _subdirs, names in os.walk(args.path):
		for name in names:
			if name.endswith(".exr"):
				convert_one(os.path.join(folder, name))

if __name__ == "__main__":
	# Command-line entry point: parse options and hand them to main().
	parser = argparse.ArgumentParser(
			description='Script for simple EXR utilities',
			epilog='Cheng Zhang (chengz20@uci.edu)')

	parser.add_argument('path', metavar='path', type=str, help='input path to convert to color map')
	# The colormap limits are numbers; parse them as floats so invalid values are
	# rejected by argparse with a usage message instead of failing inside matplotlib.
	parser.add_argument('-vmin', metavar='vmin', type=float, help='minimum value for color map')
	parser.add_argument('-vmax', metavar='vmax', type=float, help='maximum value for color map')
	parser.add_argument('-resize', metavar='resize', type=float, default=1.0, help='rescale scalar for error image')
	# Integer flag: any value > 0 enables the colorbar (see main()).
	parser.add_argument('-colorbar', metavar='colorbar', type=int, default=0, help='add colorbar to figure or not')

	args = parser.parse_args()
	main(args)