import psdr
from pypsdr import SceneManager, SceneTransform
import pypsdr
import torch
import sys, os
import argparse
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import time

def set_direct_guiding(scene_manager, gspec, integrator, options, quiet=False):
    """Build the direct-guiding data on ``scene_manager``.

    This is a pass-through to ``scene_manager.set_direct_guiding``. All four
    arguments are forwarded positionally and in the same order. Nothing is
    returned.

    Args:
        scene_manager: Object exposing ``set_direct_guiding``; here a
            ``pypsdr.SceneManager``.
        gspec: Guiding specification. In this script it is a 4-element list
            such as ``[40000, 1, 1, 8]``. NOTE(review): the meaning of each
            entry is defined by pypsdr; confirm it there.
        integrator: Integrator the guiding is built for.
        options: Render options forwarded unchanged.
        quiet: Forwarded as the fourth positional argument. Presumably it
            suppresses progress output.
    """
    guide = scene_manager.set_direct_guiding
    guide(gspec, integrator, options, quiet)

def create_output_directory(dir_out):
    """Create ``dir_out`` and its ``grad_img``, ``iterations`` and ``plot`` subdirectories.

    Directories that already exist are left untouched, so repeated calls are
    safe. Missing parent directories of ``dir_out`` are created as well.

    Args:
        dir_out: Output root directory. The rest of this script builds file
            paths as ``dir_out + 'name'``, so callers should pass it with a
            trailing path separator (e.g. ``'./results/'``).
    """
    for sub in ('grad_img', 'iterations', 'plot'):
        # makedirs also creates dir_out (and any missing parents) on the way.
        # exist_ok=True avoids the race in the old check-then-mkdir sequence.
        os.makedirs(os.path.join(dir_out, sub), exist_ok=True)

def time_to_string(time_elapsed):
    """Format a duration in seconds as a compact, human-readable string.

    Output forms:
        - ``"HHh MMm SS.SSs"`` when at least one hour has elapsed.
        - ``"MMm SS.SSs"`` when at least one minute has elapsed.
        - ``"S.SSs"`` otherwise.

    Hours and minutes are zero-padded to two digits. Seconds are zero-padded
    to two integer digits only when a larger unit precedes them.

    Args:
        time_elapsed: Duration in seconds (int or float), e.g. the difference
            of two ``time.time()`` readings.

    Returns:
        The formatted duration string.
    """
    # Round to the displayed precision *before* splitting into units.
    # Otherwise e.g. 59.999 would be printed as "60.00s".
    total = round(time_elapsed, 2)
    hours, rem = divmod(total, 3600)
    minutes, seconds = divmod(rem, 60)
    hours = int(hours)
    minutes = int(minutes)
    # "05.2f" means total width 5: two integer digits, the point, two decimals.
    # The previous "0>2.2f" spec had width 2, which never padded anything.
    if hours > 0:
        return "{:0>2}h {:0>2}m {:05.2f}s".format(hours, minutes, seconds)
    if minutes > 0:
        return "{:0>2}m {:05.2f}s".format(minutes, seconds)
    return "{:.2f}s".format(seconds)

def _plot_progress(fig, grid, loss_record, num_iters, lr, frame_path):
    """Redraw the optimization-progress figure, save it as one frame, then clear it.

    Panel 0 shows the image-RMSE history on a log scale. Panel ``i + 1`` shows
    the signed error of parameter ``i`` relative to its target, with a
    symmetric y-range of ``max(5 * |initial error|, 10 * lr)``.

    Args:
        fig: Matplotlib figure to draw into.
        grid: GridSpec with one row and ``len(loss_record)`` columns.
        loss_record: ``loss_record[0]`` is the RMSE history and
            ``loss_record[i + 1]`` is the error history of parameter ``i``.
        num_iters: Upper x-limit of every panel.
        lr: Learning rate; ``10 * lr`` is the minimum half-range of the
            parameter panels.
        frame_path: Destination image path.
    """
    ax = fig.add_subplot(grid[0])
    ax.plot(loss_record[0], 'b')
    ax.set_title('Img. RMSE')
    ax.set_xlim([0, num_iters])
    ax.set_yscale('log')
    for i, history in enumerate(loss_record[1:]):
        ax = fig.add_subplot(grid[i + 1])
        ax.plot(history, 'b')
        ax.set_xlim([0, num_iters])
        rng = max(abs(history[0]) * 5, 10 * lr)
        ax.set_ylim([-rng, rng])
        ax.set_title('Param. %d' % (i + 1))
    fig.savefig(frame_path, bbox_inches='tight')
    fig.clf()


def main(args):
    """Run the Veach-scene experiment selected by ``args.mode``.

    Modes:
        1: Render the initial image and the image at ``param_target``, with
           sppse0/sppse1/sppe set to 0. They are written to
           ``image_init.exr`` and ``image_target.exr``.
        2: Render with ``psdr.DirectADps`` after building direct guiding, and
           write one ``grad_img/deriv<i>.exr`` per parameter (slice ``i + 1``
           of the render output).
        other: Render the target image, then run ``num_iters`` Adam steps on
           the scene parameters. Images and losses/params are logged under
           ``iterations/`` and one progress plot per step goes under ``plot/``.

    Args:
        args: Parsed command-line namespace; only ``args.mode`` (int) is read.
    """
    dir_out = './results/'
    create_output_directory(dir_out)
    scene, integrator = pypsdr.load_mitsuba('./scene.xml')

    max_bounces = 1
    spp_target = 1024   # spp for generating reference image
    spp = 8
    sppse0 = 8
    sppse1 = 0
    sppe = 0

    options = psdr.RenderOptions(13, spp_target, max_bounces, sppe, sppse0, False)
    options.sppse1 = sppse1
    options.grad_threshold = 5e9

    scene_args = pypsdr.serialize_scene(scene)
    xforms = [[SceneTransform("SHAPE_GLOBAL_ROTATE", torch.tensor([0, 0, 1.0], dtype=torch.float), 1)]]
    # NOTE(review): init_param/param have psdr.nder entries, while xforms
    # declares one transform and param_target has one entry. Confirm that
    # psdr.nder == 1 for this scene.
    init_param = torch.tensor([0.0] * psdr.nder, dtype=torch.float)
    scene_manager = SceneManager(scene_args, xforms, init_param)
    gspec_direct = [40000, 1, 1, 8]
    param_target = torch.tensor([0.6])

    if args.mode == 1:
        options.sppse0 = options.sppse1 = options.sppe = 0
        img_init = scene_manager.render(integrator, options)[0, :, :, :]
        pypsdr.imwrite(img_init, dir_out + 'image_init.exr')
        scene_manager.set_arguments(param_target)
        img_target = scene_manager.render(integrator, options)[0, :, :, :]
        pypsdr.imwrite(img_target, dir_out + 'image_target.exr')
    elif args.mode == 2:
        integrator = psdr.DirectADps()
        options.spp = spp
        set_direct_guiding(scene_manager, gspec_direct, integrator, options, False)
        img_grad = scene_manager.render(integrator, options)
        # Slice 0 is the image (as read in mode 1); slice i + 1 is derivative i.
        for i in range(psdr.nder):
            pypsdr.imwrite(img_grad[i + 1, :, :, :], dir_out + 'grad_img/deriv%d.exr' % i)
    else:
        print('[INFO] optimization for inverse rendering starts...')
        fig = plt.figure(figsize=(3 * (psdr.nder + 1), 3))
        gs1 = fig.add_gridspec(nrows=1, ncols=psdr.nder + 1)
        # loss_record[0]: image RMSE per step; loss_record[i + 1]: error of param i.
        loss_record = [[] for _ in range(psdr.nder + 1)]

        # The target is rendered with the reference settings and the
        # integrator loaded from scene.xml, before switching to DirectADps.
        scene_manager.set_arguments(param_target)
        img_target = scene_manager.render(integrator, options)[0, :, :, :]
        pypsdr.imwrite(img_target, dir_out + 'image_target.exr')
        scene_manager.reset()

        options.spp = spp
        integrator = psdr.DirectADps()
        lossFunc = pypsdr.ADLossFunc.apply
        param = torch.tensor([0.0] * psdr.nder, dtype=torch.float, requires_grad=True)
        lr = 2e-2
        optimizer = torch.optim.Adam([param], lr=lr)
        grad_out_range = torch.tensor([0] * psdr.nder, dtype=torch.float)

        num_pyramid_lvl = 9
        weight_pyramid = 4
        options.quiet = True
        times = []    # lossFunc presumably appends render times; times[-1] is read below
        img_all = []  # passed to lossFunc, which is expected to fill it
        num_iters = 200

        # `with` guarantees both logs are closed even if a step raises.
        with open(dir_out + 'iterations/iter_loss.log', 'w') as file_loss, \
                open(dir_out + 'iterations/iter_param.log', 'w') as file_param:
            for t in range(num_iters):
                print('[Iter %3d]' % t, end=' ')
                optimizer.zero_grad()
                options.seed = t + 1
                start = time.time()
                set_direct_guiding(scene_manager, gspec_direct, integrator, options, True)
                # Format the whole duration; the old code printed only the
                # seconds remainder, misreporting runs longer than a minute.
                print('Total preprocess time: %s' % time_to_string(time.time() - start))

                img = lossFunc(scene_manager, integrator, options, param,
                               grad_out_range,
                               torch.tensor([10000.0] * psdr.nder, dtype=torch.float),  # out of range penalty (not well tested)
                               num_pyramid_lvl, weight_pyramid, -1, 0, times, img_all)

                pypsdr.imwrite(img, dir_out + 'iterations/iter_%d.exr' % t)
                # NOTE(review): if lossFunc appends on every call, img_all[0]
                # stays the first iteration's result. Confirm against ADLossFunc.
                if img_all:
                    for i in range(psdr.nder):
                        pypsdr.imwrite(img_all[0][i + 1, :, :, :],
                                       dir_out + 'iterations/iter_%d_%d.exr' % (t, i))

                # compute losses
                img_loss = (img - img_target).pow(2).mean()
                opt_loss = np.sqrt(img_loss.detach().numpy())
                param_loss = (param - param_target).pow(2).sum().sqrt().item()
                print('render time: %s; opt. loss: %.3e; param. loss: %.3e'
                      % (time_to_string(times[-1]), opt_loss, param_loss))

                # write image/param loss
                file_loss.write("%d, %.5e, %.5e, %.2e\n" % (t, opt_loss, param_loss, times[-1]))
                file_loss.flush()

                # write param values
                file_param.write("%d" % t
                                 + "".join(", %.5e" % v for v in param.detach().tolist())
                                 + "\n")
                file_param.flush()

                # plot the results
                loss_record[0].append(opt_loss)
                for i in range(psdr.nder):
                    loss_record[i + 1].append((param[i].detach() - param_target[i]).item())
                _plot_progress(fig, gs1, loss_record, num_iters, lr,
                               dir_out + 'plot/frame_{:03d}.png'.format(t))

                img_loss.backward()
                optimizer.step()

                grad_out_range = scene_manager.set_arguments(param)
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: parse -mode and hand the namespace to main().
    cli = argparse.ArgumentParser(description='Script for Veach Scene',
                                  epilog='Cheng Zhang (chengz20@uci.edu)')
    cli.add_argument('-mode', type=int, default=0, metavar='mode', help='[0] optimization\
                                                                            [1] Tune param for target image\
                                                                            [2] Tune param for optimization')
    main(cli.parse_args())