import vredner
from pyvredner import SceneManager, SceneTransform
import pyvredner
import torch
import sys, os
import argparse
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import time


def time_to_string(time_elapsed):
    """Format a duration in seconds as a compact human-readable string.

    Leading zero-valued units are omitted, e.g. ``"01h 02m 03.50s"``,
    ``"01m 05.25s"`` or ``"5.50s"``.

    Args:
        time_elapsed: Elapsed time in seconds (int or float).

    Returns:
        The formatted duration string.
    """
    hours, rem = divmod(time_elapsed, 3600)
    minutes, seconds = divmod(rem, 60)
    hours = int(hours)
    minutes = int(minutes)
    # `05.2f` zero-pads the seconds to two integer digits ("03.50"), matching
    # the two-digit hours/minutes fields; a width of 2 would never pad a
    # two-decimal float.
    if hours > 0:
        return "{:0>2}h {:0>2}m {:05.2f}s".format(hours, minutes, seconds)
    if minutes > 0:
        return "{:0>2}m {:05.2f}s".format(minutes, seconds)
    return "{:.2f}s".format(seconds)

def create_output_directory(dir_out):
    """Create ``dir_out`` (and any missing parent directories) if needed.

    Calling this on an existing directory is a no-op.

    Args:
        dir_out: Path of the directory to create.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates intermediate directories, which os.mkdir cannot.
    os.makedirs(dir_out, exist_ok=True)


def main(args):
    """Render the validation scene and write the results to ``./results/``.

    ``args.integrator`` selects the rendering mode:

    * ``'fd'`` -- estimate the derivative by central finite differences:
      render with ``vredner.VolPathTracer`` with the scene-transform argument
      set to ``-fd_delta`` and ``+fd_delta`` (resetting the scene manager in
      between) and write only ``deriv_fd.exr``.
    * any other supported name -- render once with the chosen integrator and
      write index 0 of the render output to ``orig_<name>.exr`` and index 1
      to ``deriv_<name>.exr``.

    Args:
        args: Parsed command-line namespace with an ``integrator`` attribute.

    Raises:
        ValueError: If ``args.integrator`` is not a supported name.
    """
    dir_out = './results/'
    create_output_directory(dir_out)
    max_bounces = 9
    # The integrator loaded from the scene file is replaced by every valid
    # branch of the dispatch below.
    scene, integrator = pyvredner.load_mitsuba('./scene.xml')
    # os.path.join avoids the doubled separator of dir_out + '/...'.
    fn_orig = os.path.join(dir_out, 'orig_%s.exr' % args.integrator)
    fn_deriv = os.path.join(dir_out, 'deriv_%s.exr' % args.integrator)
    sppse0 = 0
    sppse1 = 0
    spp = 0
    sppe = 0

    if args.integrator == 'volpath':
        spp = 1024 * 16
        integrator = vredner.VolPathTracer()
    elif args.integrator == 'volpath_simple':
        spp = 8192
        integrator = vredner.VolPathTracerSimple()
    elif args.integrator == 'volpathAD':
        spp = 1024
        sppse0 = 4
        sppe = 1024
        integrator = vredner.VolPathTracerAD()
    elif args.integrator == 'fd':
        # spp = 1024 * 8 * 64          #delta = 1e-1
        spp = 1024 * 8 * 128         #delta = 2e-2
        integrator = vredner.VolPathTracer()
    elif args.integrator == 'volpathADps':
        spp = 0
        sppse0 = 2048
        integrator = vredner.VolpathSimpleADps()
    elif args.integrator == 'volpathADps2':
        spp = 2048
        # spp = 1024                # 10m 36s
        # spp = 65536               # 11h 20m
        # sppe = 128
        # sppse0 = 8192
        sppe = 1024
        integrator = vredner.VolPathADps()
    else:
        # A real exception (not ``assert``) so the check survives ``python -O``.
        raise ValueError('Unknown integrator: %r' % args.integrator)

    # NOTE(review): this unconditionally overrides the per-integrator sppe
    # values chosen above (currently each is either 0 or 1024) -- confirm the
    # override is intentional.
    # sppe = 1024 * 64 * 8
    sppe = 1024
    # NOTE(review): meaning of the leading 13 is not visible here (presumably
    # a seed) -- confirm against vredner.RenderOptions.
    options = vredner.RenderOptions(13, spp, max_bounces, sppe, sppse0, False)
    options.sppse0 = sppse0
    options.sppse1 = sppse1
    scene_args = pyvredner.serialize_scene(scene)
    xforms = []
    xforms.append( [SceneTransform("SHAPE_GLOBAL_ROTATE", torch.tensor([ 0.0, 0.2, 0.0], dtype=torch.float), 1),
                    SceneTransform("MEDIUM_VARY", "heterogeneous", 0,
                                    torch.tensor([ 0.0, 0.0, 0.0]),
                                    torch.tensor([ 0.0, 0.2, 0.0]),
                                    0, 0)] )
    scene_manager = SceneManager(scene_args, xforms)
    # scene_manager.set_primary_guiding([10000, 0, 0, 1024], integrator, options, quiet = False)
    # plt.plot(scene_manager.grid_guide_primary.numpy())
    # plt.savefig("visPrimary.png")
    if args.integrator == 'fd':
        # Earlier forward-difference variant (delta = 1e-1), kept for reference:
        # fd_delta = 1e-1
        # img0 = scene_manager.render(integrator, options)[0, :, :, :]
        # # pyvredner.imwrite(img0, dir_out+"orig_%s0.exr"%(args.integrator))
        # scene_manager.set_arguments(torch.tensor([fd_delta]))
        # img1 = scene_manager.render(integrator, options)[0, :, :, :]
        # # pyvredner.imwrite(img1, dir_out+"orig_%s1.exr"%(args.integrator))
        # img_fd = (img1-img0)/fd_delta
        # pyvredner.imwrite(img_fd, dir_out+"deriv_%s.exr"%(args.integrator))
        # Central difference: (f(+delta) - f(-delta)) / (2 * delta).
        fd_delta = 2e-2
        scene_manager.set_arguments(torch.tensor([-fd_delta]))
        img0 = scene_manager.render(integrator, options)[0, :, :, :]
        scene_manager.reset()
        scene_manager.set_arguments(torch.tensor([fd_delta]))
        img1 = scene_manager.render(integrator, options)[0, :, :, :]
        img_fd = (img1 - img0) / (2 * fd_delta)
        pyvredner.imwrite(img_fd, fn_deriv)
    else:
        img_new = scene_manager.render(integrator, options)
        # Index 0 of the render output is written as the image, index 1 as
        # the derivative.
        img_deriv = img_new[1, :, :, :]
        pyvredner.imwrite(img_new[0], fn_orig)
        pyvredner.imwrite(img_deriv, fn_deriv)


if __name__ == "__main__":
    # Integrator names accepted by main(); keep in sync with its dispatch.
    integrator_choices = ['volpath', 'volpath_simple', 'volpathAD', 'fd',
                          'volpathADps', 'volpathADps2']
    parser = argparse.ArgumentParser(
            description='Script for generating validation results',
            epilog='Cheng Zhang (chengz20@uci.edu)')

    # `choices` rejects a mistyped name with a usage message before any
    # output directory is created or the scene is loaded.
    parser.add_argument('integrator', metavar='integrator', nargs='?', type=str,
                        default="volpathADps2", choices=integrator_choices,
                        help='integrator to run (one of: %(choices)s; '
                             'default: %(default)s)')

    args = parser.parse_args()
    main(args)