from pyvredner import SceneManager, SceneTransform
import pyvredner
import vredner
import torch
import sys
import os
import argparse
import numpy as np
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import time


def time_to_string(time_elapsed):
    """Format a duration given in seconds as a compact human-readable string.

    The hour field appears only for durations of at least one hour, and the
    minute field only for durations of at least one minute. Both are
    zero-padded to two digits. Seconds are always shown with two decimals,
    e.g. ``"01h 02m 5.25s"``, ``"01m 5.50s"`` or ``"0.50s"``.

    Args:
        time_elapsed: duration in seconds (int or float).

    Returns:
        The formatted string.
    """
    whole_hours, remainder = divmod(time_elapsed, 3600)
    whole_minutes, secs = divmod(remainder, 60)
    h, m = int(whole_hours), int(whole_minutes)

    # The seconds suffix is common to every branch.
    secs_str = "{:0>2.2f}s".format(secs)
    if h > 0:
        return "{:0>2}h {:0>2}m ".format(h, m) + secs_str
    if m > 0:
        return "{:0>2}m ".format(m) + secs_str
    return secs_str


def create_output_directory(dir_out):
    """Ensure ``dir_out`` and its ``iterations`` sub-directory exist.

    Missing parent directories are created too, and existing directories are
    left untouched, so calling this repeatedly is safe.

    Args:
        dir_out: output directory path; a trailing separator is optional.
    """
    # makedirs(..., exist_ok=True) creates dir_out on the way and avoids the
    # exists()/mkdir() race. os.path.join avoids producing "<dir>iterations"
    # when dir_out lacks a trailing slash.
    os.makedirs(os.path.join(dir_out, 'iterations'), exist_ok=True)


def _select_integrator(name):
    """Return ``(spp_iter, sppse0, sppse1, integrator)`` for the named integrator.

    Args:
        name: one of ``'volpathAD'``, ``'volpathADps'`` or ``'hybridADps'``.

    Raises:
        ValueError: if ``name`` is not a supported integrator.
    """
    if name == 'volpathAD':
        return 128, 32, 0, vredner.VolPathTracerAD()
    if name == 'volpathADps':
        return 64, 0, 0, vredner.VolPathADps()
    if name == 'hybridADps':
        return 512, 128, 128, vredner.hybridADps()
    # Explicit raise rather than ``assert(False)``: asserts are stripped under
    # ``python -O``, which would leave the sample counts undefined.
    raise ValueError("unknown integrator %r (expected 'volpathAD', "
                     "'volpathADps' or 'hybridADps')" % (name,))


def _plot_progress(fig, grid, loss_record, num_iters, learning_rate, path):
    """Draw the image-loss and per-parameter-error panels, save them to ``path``, then clear ``fig``.

    ``loss_record[0]`` holds the image MSE history. ``loss_record[i]`` for
    ``i >= 1`` holds the history of (parameter ``i-1`` minus its target).
    """
    ax = fig.add_subplot(grid[0])
    ax.plot(loss_record[0], 'b')
    ax.set_title('Image MSE')
    ax.set_xlim([0, num_iters])
    for i in range(1, len(loss_record)):
        ax = fig.add_subplot(grid[i])
        ax.plot(loss_record[i], 'b')
        ax.set_xlim([0, num_iters])
        # Symmetric y-range: 5x the initial error, but never narrower than the step size.
        rng = max(abs(loss_record[i][0]) * 5, learning_rate)
        ax.set_ylim([-rng, rng])
        ax.set_title('Param. %d' % i)
    fig.savefig(path, bbox_inches='tight')
    fig.clf()


def main(args):
    """Recover emitter-transform parameters by inverse rendering.

    Loads ``./scene.xml`` and renders a target image with the parameters set
    to ``[2.0, 1.0]``, using a ``hybridADps`` integrator. It then runs Adam
    from all-zero parameters for ``num_iters`` iterations with the integrator
    chosen by ``args.integrator``.

    Everything is written under ``./results_<integrator>/``:
      - per-iteration images in ``iterations/``;
      - ``error.log`` with ``iter, image MSE, param RMSE, render time`` per line;
      - ``param.log`` with ``iter, p0, p1, ...`` per line;
      - plot frames in ``plot/``.

    Args:
        args: parsed CLI namespace; only ``args.integrator`` is read.

    Raises:
        ValueError: for an unknown integrator name, or if ``vredner.nder``
            does not match the number of target parameters.
    """
    dir_out = './results_%s/' % args.integrator
    create_output_directory(dir_out)
    iter_dir = os.path.join(dir_out, 'iterations')
    plot_dir = os.path.join(dir_out, 'plot')
    max_bounces = 7
    # The integrator returned by load_mitsuba is replaced by the one selected below.
    scene, integrator = pyvredner.load_mitsuba('./scene.xml')
    spp_target = 1024 * 10

    spp_iter, sppse0, sppse1, integrator = _select_integrator(args.integrator)

    options = vredner.RenderOptions(13, spp_target, max_bounces, 0, 0, False)
    options.sppse = sppse0
    options.sppe = 0
    options.sppse0 = sppse0
    options.sppse1 = sppse1

    scene_args = pyvredner.serialize_scene(scene)
    init_param = torch.tensor([0.0] * vredner.nder, dtype=torch.float)
    target_param = torch.tensor([2.0, 1.0], dtype=torch.float)
    # One target value (and one xform group below) per differentiable parameter.
    if target_param.numel() != vredner.nder:
        raise ValueError('expected vredner.nder == %d, got %d'
                         % (target_param.numel(), vredner.nder))
    xforms = [
        [SceneTransform("EMITTER_POINT_VARY", 0,
                        torch.tensor([0.0, 0.0, 0.0], dtype=torch.float),
                        torch.tensor([100.0, 0.0, 0.0], dtype=torch.float))],
        [SceneTransform("EMITTER_POINT_VARY", 0,
                        torch.tensor([0.0, 0.0, 0.0], dtype=torch.float),
                        torch.tensor([0.0, 0.0, -50.0], dtype=torch.float))],
    ]
    scene_manager = SceneManager(scene_args, xforms, init_param)

    # Render the (optional) initial image and the target image with a fixed
    # reference integrator, at the high target sample count.
    integrator0 = vredner.hybridADps()
    write_init_image = False
    if write_init_image:
        img_init = scene_manager.render(integrator0, options)[0, :, :, :]
        pyvredner.imwrite(img_init, os.path.join(dir_out, 'image_init.exr'))
    scene_manager.set_arguments(target_param)
    img_target = scene_manager.render(integrator0, options)[0, :, :, :]
    pyvredner.imwrite(img_target, os.path.join(dir_out, 'image_target.exr'))
    options.spp = spp_iter
    scene_manager.reset()

    # optimization hyper-params
    learning_rate = 1e-1  # 1/20 1/50 target_param
    num_pyramid_lvl = 9
    weight_pyramid = 4
    options.quiet = False
    num_iters = 200

    param = torch.tensor([0.0] * vredner.nder,
                         dtype=torch.float, requires_grad=True)
    optimizer = torch.optim.Adam([param], lr=learning_rate)
    lossFunc = pyvredner.ADLossFunc.apply
    grad_out_range = torch.tensor([0] * vredner.nder, dtype=torch.float)
    print('[INFO] optimization starts...')
    # Both lists are passed to lossFunc, which fills them (times[-1] is read below).
    times = []
    img_all = []

    fig = plt.figure(figsize=(3 * (vredner.nder + 1), 3))
    grid = fig.add_gridspec(nrows=1, ncols=vredner.nder + 1)
    os.makedirs(plot_dir, exist_ok=True)
    # loss_record[0]: image MSE; loss_record[i + 1]: param i minus its target.
    loss_record = [[] for _ in range(vredner.nder + 1)]

    error_log = os.path.join(dir_out, 'error.log')
    param_log = os.path.join(dir_out, 'param.log')
    # Context managers so the logs are closed even if rendering raises.
    with open(error_log, 'w') as file_error, open(param_log, 'w') as file_param:
        for t in range(num_iters):
            print('[Iter %3d]' % t, end=' ')
            optimizer.zero_grad()
            options.seed = t + 1
            img = lossFunc(scene_manager, integrator, options, param,
                           # out of range penalty (not well tested)
                           grad_out_range,
                           torch.tensor([10000.0] * vredner.nder, dtype=torch.float),
                           num_pyramid_lvl, weight_pyramid, -1, 0, times, img_all)

            pyvredner.imwrite(img, os.path.join(iter_dir, 'iter_%d.exr' % t))
            # Only dump per-parameter images if lossFunc actually filled img_all
            # (a list is never None, so the old check could hit IndexError).
            # NOTE(review): assumes img_all[0] is the current iteration's image
            # stack with slot 0 skipped -- confirm against ADLossFunc.
            if img_all:
                for i in range(vredner.nder):
                    pyvredner.imwrite(img_all[0][i + 1, :, :, :],
                                      os.path.join(iter_dir, 'iter_%d_%d.exr' % (t, i)))

            # compute losses
            image_loss = (img - img_target).pow(2).mean()  # image MSE
            param_loss = (param - target_param).pow(2).mean().sqrt()  # param RMSE
            print('render time: %s; image. RMSE: %.3e; param. RMSE: %.3e' %
                  (time_to_string(times[-1]), image_loss.sqrt().item(), param_loss.item()))
            # write image/param loss
            file_error.write("%d, %.5e, %.5e, %.2e\n" %
                             (t, image_loss.item(), param_loss.item(), times[-1]))
            file_error.flush()

            # write param values
            file_param.write("%d" % t)
            for value in param.detach().tolist():
                file_param.write(", %.5e" % value)
            file_param.write("\n")
            file_param.flush()

            # record and plot progress (plain floats, not 0-d tensors)
            loss_record[0].append(image_loss.item())
            for i in range(vredner.nder):
                loss_record[i + 1].append((param[i].detach() - target_param[i]).item())
            _plot_progress(fig, grid, loss_record, num_iters, learning_rate,
                           os.path.join(plot_dir, 'frame_{:03d}.png'.format(t)))

            image_loss.backward()
            optimizer.step()
            grad_out_range = scene_manager.set_arguments(param)

    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Script for generating validation results',
        epilog='Cheng Zhang (chengz20@uci.edu)')

    # Restrict to the integrators main() knows how to configure, so a typo is
    # reported by argparse with a usage message. Otherwise main() would first
    # create the output directory and load the scene.
    parser.add_argument('integrator', metavar='integrator', nargs="?", type=str,
                        default="hybridADps",
                        choices=['volpathAD', 'volpathADps', 'hybridADps'],
                        help='integrator to optimize with: volpathAD, volpathADps '
                             'or hybridADps (default: %(default)s)')

    args = parser.parse_args()
    main(args)
