import vredner
from pyvredner import SceneManager, SceneTransform
import pyvredner
import torch
import sys
import argparse

# img_1 = pyvredner.imread('../example_1/dielectric/non_edge_32768_bdptAD.exr')
# img_2 = pyvredner.imread('../example_1/dielectric/primary_edge_32768.exr')
# img_3 = pyvredner.imread('../example_1/dielectric/indirect_edge_16384_guided.exr')

# img = img_1 +  img_2 + img_3
# pyvredner.imwrite(img, '../example_1/dielectric/all.exr')

# img = pyvredner.imread('../example_1/dielectric/deriv_fd_particle.exr')
# pyvredner.imwrite(torch.abs(img), '../example_1/dielectric/deriv_fd_particle_abs.exr')
def main(args):
    """Run one simple EXR utility operation described by parsed CLI args.

    Supported operators (``args.operator``):
      * ``abs``    -- write the element-wise absolute value of input[0]
      * ``neg``    -- write the negation of input[0]
      * ``add``    -- write the element-wise sum of all inputs
      * ``pt2exr`` -- load a tensor with ``torch.load`` and write it out
      * ``diff``   -- write input[0] - input[1]
      * ``divide`` -- write |input[0]| / |input[1]| (presumably float
        images, so zero denominators would give inf/nan -- not guarded)

    Extra inputs beyond what an operator reads are ignored (except ``add``,
    which sums all of them).

    Args:
        args: namespace with ``operator`` (str), ``input`` (list of paths,
            or None if not given) and ``output`` (path, or None if not
            given), as produced by this script's argparse parser.

    Raises:
        SystemExit: with status 1 if the operator is unknown, fewer inputs
            were given than the operator needs, or no output path was given.
    """
    # Minimum number of --input paths each supported operator reads.
    min_inputs = {
        'abs': 1,
        'neg': 1,
        'add': 1,
        'pt2exr': 1,
        'diff': 2,
        'divide': 2,
    }

    op = args.operator
    if op not in min_inputs:
        print("[ERROR] operator '%s' is not supported" % op, file=sys.stderr)
        # Exit non-zero so calling shell scripts / pipelines see the failure.
        sys.exit(1)

    # Validate up front so a bad invocation fails with a clear message
    # instead of an IndexError/TypeError deep inside an operator branch.
    inputs = args.input or []
    if len(inputs) < min_inputs[op]:
        print("[ERROR] operator '%s' needs at least %d input(s), got %d"
              % (op, min_inputs[op], len(inputs)), file=sys.stderr)
        sys.exit(1)
    if not args.output:
        print("[ERROR] operator '%s' requires --output" % op, file=sys.stderr)
        sys.exit(1)

    if op == 'abs':
        print("[INFO] abs the input image")
        img = pyvredner.imread(inputs[0])
        pyvredner.imwrite(torch.abs(img), args.output)
    elif op == 'neg':
        print("[INFO] negate the input image")
        img = pyvredner.imread(inputs[0])
        pyvredner.imwrite(-img, args.output)
    elif op == 'add':
        print('[INFO] add %d images' % len(inputs))
        # In-place accumulation is safe: img is a freshly read image.
        img = pyvredner.imread(inputs[0])
        for path in inputs[1:]:
            img += pyvredner.imread(path)
        pyvredner.imwrite(img, args.output)
    elif op == 'pt2exr':
        print("[INFO] convert a saved tensor to EXR")
        # NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code -- only use this on trusted .pt files.
        img = torch.load(inputs[0])
        pyvredner.imwrite(img, args.output)
    elif op == 'diff':
        print("[INFO] subtract the second image from the first")
        img1 = pyvredner.imread(inputs[0])
        img2 = pyvredner.imread(inputs[1])
        pyvredner.imwrite(img1 - img2, args.output)
    elif op == 'divide':
        print("[INFO] divide |first image| by |second image|")
        img1 = pyvredner.imread(inputs[0])
        img2 = pyvredner.imread(inputs[1])
        pyvredner.imwrite(img1.abs() / img2.abs(), args.output)

if __name__ == "__main__":
    # Command-line entry point: parse arguments and dispatch to main().
    parser = argparse.ArgumentParser(
            description='Script for simple EXR utilities',
            epilog='Cheng Zhang (chengz20@uci.edu)')

    parser.add_argument('operator', metavar='operator', type=str,
                        help='EXR operation to perform: '
                             'abs, neg, add, pt2exr, diff, divide')
    # Every operator reads at least one input and writes one output, so
    # require both here; argparse then reports a usage error instead of
    # the script crashing later on a None path.
    parser.add_argument('--output', metavar='output', type=str, required=True,
                        help='output path')
    parser.add_argument('--input', metavar='input', type=str, nargs='+',
                        required=True, help='input path(s)')

    args = parser.parse_args()
    main(args)