import numpy as np
import torch, argparse, os, imageio, sys, re, time
import enoki
from enoki.cuda_autodiff import Float32 as FloatD, Vector3f as Vector3fD
import psdr_cuda
import neural_sdf
import cv2
from meshSDF import *
from renderer import RenderLayer_SDF_NRF
from utils import *
from collections import OrderedDict
import random
import subprocess
from fnmatch import fnmatch

def render_visualization(vtx_bsdf_params, scene_dir, scene_indices, sensor_indices, image_shape, render_spec):
    """Render one pass-averaged image per (scene, sensor) pair for visual inspection.

    For every pair, ``<scene_dir>/opt_<scene_id>.xml`` is loaded and the
    per-vertex BSDF parameters of ``Mesh[0]`` are overwritten from
    ``vtx_bsdf_params``: columns 0-2 diffuse, 3-5 specular, 6 roughness. Then
    ``render_spec['npass']`` renders of ``render_spec['spp']`` samples each
    are averaged.

    Args:
        vtx_bsdf_params: per-vertex BSDF parameters (column layout as above).
        scene_dir: directory containing the ``opt_<k>.xml`` scene files.
        scene_indices: scene id of each image to render.
        sensor_indices: local sensor id of each image (paired with ``scene_indices``).
        image_shape: shape each averaged image is reshaped to.
        render_spec: dict with at least ``npass``, ``spp``, ``log_level``,
            ``integrator``, plus ``intensity`` (collocated) or
            ``sppse_mode``/optional ``hide_envmap`` (direct).

    Returns:
        List of torch tensors, one per requested (scene, sensor) pair.
    """
    frames = [None] * len(sensor_indices)
    for slot, (scene_id, sensor_id) in enumerate(zip(scene_indices, sensor_indices)):
        scene = psdr_cuda.Scene()
        scene.load_file(os.path.join(scene_dir, f'opt_{scene_id}.xml'), False)
        npass = render_spec['npass']
        scene.opts.spp = render_spec['spp']
        scene.opts.log_level = render_spec['log_level']

        mesh = scene.param_map["Mesh[0]"]
        mesh.vertex_diffuse = Vector3fD(vtx_bsdf_params[:, 0:3])
        mesh.vertex_specular = Vector3fD(vtx_bsdf_params[:, 3:6])
        mesh.vertex_roughness = FloatD(vtx_bsdf_params[:, 6])
        scene.configure2(sensor_indices)

        kind = render_spec['integrator']
        if kind == 'collocated':
            integrator = psdr_cuda.CollocatedIntegratorNRF(render_spec['intensity'])
        elif kind == 'direct':
            integrator = psdr_cuda.DirectIntegratorNRF(1, 1, render_spec['sppse_mode'])
            if 'hide_envmap' in render_spec:
                integrator.hide_emitters = render_spec['hide_envmap']
        else:
            assert False, "Doesn't support this integrator."

        # Sum the passes; the average is taken once every pair is rendered.
        total = None
        for _ in range(npass):
            frame = integrator.renderC(scene, sensor_id).torch()
            if total is None:
                total = frame
            else:
                total += frame
        frames[slot] = total

        # Release the scene and integrator before asking enoki to trim its allocator.
        del scene, integrator
        enoki.cuda_malloc_trim()

    averaged = []
    for total in frames:
        averaged.append((total / float(npass)).reshape(image_shape))
    return averaged

def render_scene(scene, integrator, sil_integrator, vtx, vtx_bsdf_params, sensor_id):
    """Render the shaded image and the silhouette of one sensor through ``RenderLayer_SDF_NRF``.

    ``integrator`` produces the shaded image and ``sil_integrator`` the
    silhouette. Both are rendered with one layer each, restricted to
    ``sensor_id``.

    Returns:
        Tuple ``(img, sil_img, vtx.detach(), render_time)``. Both images are
        reshaped to ``(scene.opts.height, scene.opts.width, 3)``, and
        ``render_time`` is the wall-clock seconds spent in the two renders.
    """
    tic = time.time()
    layers = []
    images = []
    # Shaded pass first, then the silhouette pass (same order as the loss consumes them).
    for integ in (integrator, sil_integrator):
        layer = RenderLayer_SDF_NRF(scene, integ, 1, sensor_indices=[sensor_id], quiet=True)
        layers.append(layer)
        out = layer(vtx, vtx_bsdf_params)[0]
        images.append(out.reshape(scene.opts.height, scene.opts.width, 3))
    elapsed = time.time() - tic

    # Garbage collection: drop the render layers, then let enoki trim its allocator.
    del layers, layer
    enoki.cuda_malloc_trim()

    shaded, silhouette = images
    return shaded, silhouette, vtx.detach(), elapsed

def active_sensors(batch, num_scenes, num_sensors, vis_sensor_indices, weighted_sensor_indices, sensor_weight):
    """Randomly pick ``batch`` distinct global sensor indices for one optimization step.

    Global indices range over ``range(num_scenes * num_sensors)``. They are
    drawn without replacement, with probability proportional to each index's
    weight:

    * 0 for indices in ``vis_sensor_indices``, so they are never sampled. This
      takes precedence over ``weighted_sensor_indices``.
    * ``sensor_weight`` for indices in ``weighted_sensor_indices``.
    * 1 for every other index.

    Args:
        batch: number of sensors to draw.
        num_scenes: number of scenes.
        num_sensors: sensors per scene.
        vis_sensor_indices: global indices excluded from sampling.
        weighted_sensor_indices: global indices sampled with ``sensor_weight``.
        sensor_weight: relative weight of ``weighted_sensor_indices``.

    Returns:
        numpy array of ``batch`` distinct global sensor indices.

    Raises:
        ValueError: if no sensor, or fewer than ``batch`` sensors, has a
            non-zero sampling weight.
    """
    # Sets give O(1) membership instead of scanning the lists once per sensor.
    excluded = set(vis_sensor_indices)
    weighted = set(weighted_sensor_indices)

    candidates = list(range(num_scenes * num_sensors))
    weights = np.array([
        0.0 if s in excluded else float(sensor_weight) if s in weighted else 1.0
        for s in candidates
    ])

    # Without this check an all-zero weight vector normalizes to NaN, and
    # np.random.choice fails with an unrelated-looking error.
    eligible = int(np.count_nonzero(weights))
    if eligible == 0 or eligible < batch:
        raise ValueError(
            f'Cannot sample {batch} sensors: only {eligible} of {len(candidates)} '
            f'have a non-zero sampling weight.')

    weights /= weights.sum()
    return np.random.choice(candidates, p=weights, size=batch, replace=False)

# Optimize a neural SDF and a neural reflectance field against pre-rendered target images.
def optimize(max_iter          = 500,
             resolution        = 128,
             scene_dir         = 'kitty',
             batch_size        = 4,
             sdf_lr            = 1e-3,
             sdf_sched_milestones  = None,
             sdf_sched_factor      = 0.75,
             nrf_lr            = 1e-3,
             nrf_sched_milestones  = None,
             nrf_sched_factor      = 0.75,
             checkpoint_iter   = 100,
             from_checkpoint   = -1,
             vis_sensor_indices = None,
             weighted_sensor_indices = None,
             sensor_weight = 5,
             img_weight = 10.0,
             eikonal_weight = 0.1,
             sil_weight = 100,
             scaling = None,
             bsdf_param_max = None,
             bsdf_param_min = None,
             diffuse_only = False,
             disable_nrf = False,
             render_spec       = None):
    """Jointly optimize a neural SDF and a neural reflectance field (NRF) to match target renders.

    Each iteration does the following:

    1. Extract a mesh from the SDF with ``IsoSurface`` and write it to
       ``<scene_dir>/model/sdf.obj``.
    2. Predict per-vertex BSDF parameters with the NRF.
    3. Render a random batch of sensors from the ``opt_<k>.xml`` scenes.
    4. Back-propagate an image L1 loss, a silhouette L1 loss and an eikonal
       regularizer.

    A visualization grid is written to ``<scene_dir>/results_<integrator>``
    every iteration. Checkpoints are written to
    ``<scene_dir>/checkpoint_<integrator>`` every ``checkpoint_iter``
    iterations.

    Args:
        max_iter: last iteration index (inclusive).
        resolution: marching-cubes grid resolution per axis over ``[-1, 1]^3``.
        scene_dir: directory holding ``opt_*.xml``, ``target/exr`` and ``target_sil/exr``.
        batch_size: number of sensors rendered per iteration.
        sdf_lr, nrf_lr: Adam learning rates.
        sdf_sched_milestones, nrf_sched_milestones: MultiStepLR milestones
            (``None`` means ``[250]``).
        sdf_sched_factor, nrf_sched_factor: MultiStepLR decay factors.
        checkpoint_iter: checkpoint period in iterations.
        from_checkpoint: iteration to resume from; negative starts from scratch.
        vis_sensor_indices: global sensor indices used only for visualization
            (``None`` means ``[0]``).
        weighted_sensor_indices: global sensor indices sampled with
            ``sensor_weight`` (``None`` means ``[]``).
        sensor_weight: relative sampling weight of ``weighted_sensor_indices``.
        img_weight, eikonal_weight, sil_weight: loss weights.
        scaling: forwarded to ``IsoSurface`` (``None`` means ``torch.ones(3)``).
        bsdf_param_max, bsdf_param_min: per-channel range the NRF output is
            mapped to (``None`` means ``[1]*7`` and ``[0]*6 + [0.05]`` on CUDA).
        diffuse_only: zero BSDF channels 3 onwards.
        disable_nrf: never step the NRF optimizer.
        render_spec: renderer settings (``None`` means the collocated defaults
            listed below).
    """
    # Defaults are built per call. The previous signature defaults allocated
    # CUDA tensors at import time and shared list/dict/tensor objects across calls.
    if sdf_sched_milestones is None:
        sdf_sched_milestones = [250]
    if nrf_sched_milestones is None:
        nrf_sched_milestones = [250]
    if vis_sensor_indices is None:
        vis_sensor_indices = [0]
    if weighted_sensor_indices is None:
        weighted_sensor_indices = []
    if scaling is None:
        scaling = torch.ones(3)
    if bsdf_param_max is None:
        bsdf_param_max = torch.tensor([1, 1, 1, 1, 1, 1, 1]).float().cuda()
    if bsdf_param_min is None:
        bsdf_param_min = torch.tensor([0, 0, 0, 0, 0, 0, 0.05]).float().cuda()
    if render_spec is None:
        render_spec = {
            'integrator': 'collocated',
            'intensity': 10,
            'sppse_mode': 1,
            'spp': 32,
            'sppe': 8,
            'sppse': 2,
            'log_level': 0,
            'npass': 1,
            'hide_envmap': False,
        }

    if render_spec['integrator'] == 'collocated':
        integrator = psdr_cuda.CollocatedIntegratorNRF(render_spec['intensity'])
    elif render_spec['integrator'] == 'direct':
        integrator = psdr_cuda.DirectIntegratorNRF(1, 1, render_spec['sppse_mode'])
        if 'hide_envmap' in render_spec:
            integrator.hide_emitters = render_spec['hide_envmap']
    else:
        assert False, "Doesn't support this integrator."

    sil_integrator = psdr_cuda.FieldExtractionIntegrator("silhouette")

    result_dir = os.path.join(scene_dir, 'results_' + render_spec['integrator'])
    os.makedirs(result_dir, exist_ok=True)

    target_dir = os.path.join(scene_dir, 'target', 'exr')
    target_sil_dir = os.path.join(scene_dir, 'target_sil', 'exr')
    assert os.path.exists(target_dir), "Can't find target images."
    assert os.path.exists(target_sil_dir), "Can't find target sillhouette images."

    # Load the target images. Global sensor index k belongs to scene
    # k // num_sensors as local sensor k % num_sensors; files are named sensor_<k>.
    num_scenes = len([filename for filename in os.listdir(scene_dir) if fnmatch(filename, 'opt_*.xml')])
    # Without this check the division below fails with a bare ZeroDivisionError.
    assert num_scenes > 0, "Can't find any opt_*.xml scene files."
    num_sensors = len(os.listdir(target_dir)) // num_scenes
    _vis_scene_indices = [idx // num_sensors for idx in vis_sensor_indices]
    _vis_sensor_indices = [idx % num_sensors for idx in vis_sensor_indices]
    target_images = [torch.from_numpy(read_exr(os.path.join(target_dir, f'sensor_{i}'))).cuda() for i in range(num_sensors * num_scenes)]
    target_sil_images = [torch.from_numpy(read_exr(os.path.join(target_sil_dir, f'sensor_{i}'))).cuda() for i in range(num_sensors * num_scenes)]

    checkpoint_dir = os.path.join(scene_dir, 'checkpoint_' + render_spec['integrator'])
    os.makedirs(checkpoint_dir, exist_ok=True)
    load_checkpoint = from_checkpoint >= 0
    if load_checkpoint:
        checkpoint = torch.load(os.path.join(checkpoint_dir, f'checkpoint_{from_checkpoint}.pt'))

    # Extracted mesh path; it is rewritten every iteration before the opt scenes are loaded.
    opt_obj = os.path.join(scene_dir, 'model/sdf.obj')

    # construct the neural SDF from IDR paper
    sdf  = neural_sdf.ImplicitNetwork(feature_vector_size = 256,
                                     d_in = 3,
                                     d_out = 1,
                                     dims = [ 512, 512, 512, 512, 512, 512, 512, 512 ],
                                     geometric_init = True,
                                     bias = 0.6,
                                     skip_in = [4],
                                     weight_norm = True,
                                     multires = 6)
    sdf = sdf.cuda()
    sdf_optimizer = torch.optim.Adam(sdf.parameters(), lr=sdf_lr)
    sdf_scheduler = torch.optim.lr_scheduler.MultiStepLR(sdf_optimizer, sdf_sched_milestones, gamma=sdf_sched_factor)

    # Neural reflectance field: 7 outputs per vertex (mapped into [bsdf_param_min, bsdf_param_max]).
    nrf = neural_sdf.RenderingNetwork(d_in = 3,
                                      d_out = 7,
                                      dims = [512, 512, 512, 512, 512, 512, 512, 512],
                                      weight_norm = True,
                                      multires_point = 10)
    nrf = nrf.cuda()
    nrf_optimizer = torch.optim.Adam(nrf.parameters(), lr=nrf_lr)
    nrf_scheduler = torch.optim.lr_scheduler.MultiStepLR(nrf_optimizer, nrf_sched_milestones, gamma=nrf_sched_factor)

    # Fully qualified: `nn` is not imported by name in this file.
    rgb_loss = torch.nn.L1Loss(reduction='sum')
    N = resolution
    bbox = BoundingBox(torch.tensor([-1.0, -1.0, -1.0], device='cuda'),
                       torch.tensor([ 1.0,  1.0,  1.0], device='cuda'),
                       torch.tensor([   N,    N,    N], device='cuda'))
    image_shape = target_images[0].shape
    # Summed L1 losses are divided by this, so with npass == 1 the batch total is a per-element mean.
    num_elm = batch_size * torch.numel(target_images[0])
    start_iter = 0

    if load_checkpoint:
        sdf.load_state_dict(checkpoint['sdf_model_state_dict'])
        sdf_optimizer.load_state_dict(checkpoint['sdf_optimizer_state_dict'])
        sdf_scheduler.load_state_dict(checkpoint['sdf_scheduler_state_dict'])
        nrf.load_state_dict(checkpoint['nrf_model_state_dict'])
        nrf_optimizer.load_state_dict(checkpoint['nrf_optimizer_state_dict'])
        nrf_scheduler.load_state_dict(checkpoint['nrf_scheduler_state_dict'])
        start_iter = checkpoint['it'] + 1

    npass = render_spec['npass']
    n_eik_points = 4096  # random points per eikonal-loss evaluation

    for it in range(start_iter, max_iter + 1):
        print_gpu_usage(f'at the start of iter {it}')
        sdf_optimizer.zero_grad()
        if not disable_nrf:
            nrf_optimizer.zero_grad()

        sensor_indices = active_sensors(batch_size, num_scenes, num_sensors, vis_sensor_indices, weighted_sensor_indices, sensor_weight)

        t0 = time.time()
        isoSurface = IsoSurface(sdf, bbox, scaling=scaling)
        vtx, faces = isoSurface()
        t1 = time.time()
        mc_time = t1 - t0
        print(f'[INFO] time spent on performing marching cube: {mc_time:.2f}')
        write_obj(vtx.detach().cpu().numpy(), faces.cpu().numpy(), opt_obj)

        # Affine map of the NRF output into [bsdf_param_min, bsdf_param_max]
        # (assumes the NRF output lies in [-1, 1] — TODO confirm).
        vtx_bsdf_params = 0.5 * (nrf(vtx) + 1.0) * (bsdf_param_max - bsdf_param_min) + bsdf_param_min
        if diffuse_only:
            vtx_bsdf_params[:, 3:] = 0.0

        total_render_time = 0.0
        total_image_loss = 0.0
        total_sil_loss = 0.0
        total_eikonal_loss = 0.0
        total_loss = 0.0

        t0 = time.time()
        for sensor_id in sensor_indices:
            scene_id = sensor_id // num_sensors
            _sensor_id = sensor_id % num_sensors
            opt_xml = os.path.join(scene_dir, f'opt_{scene_id}.xml')
            opt_scene = psdr_cuda.Scene()
            opt_scene.load_file(opt_xml, False)
            opt_scene.opts.spp = render_spec['spp']
            opt_scene.opts.sppe = render_spec['sppe']
            opt_scene.opts.sppse = render_spec['sppse']
            opt_scene.opts.log_level = render_spec['log_level']
            # NOTE(review): every pass adds a full loss term (not divided by npass), and the
            # eikonal term is not divided by batch_size, so both scale with those settings.
            for ipass in range(npass):
                # Image loss
                opt_image, opt_sil_image, vtx_pos, render_time = render_scene(opt_scene, integrator, sil_integrator, vtx, vtx_bsdf_params, _sensor_id)
                total_render_time += render_time

                img_loss = rgb_loss(target_images[sensor_id], opt_image)
                sil_loss = rgb_loss(target_sil_images[sensor_id], opt_sil_image)
                img_loss /= num_elm
                sil_loss /= num_elm

                # Eikonal loss: penalize (|grad f| - 1)^2 on the detached mesh
                # vertices plus uniform random points in [-3, 3]^3.
                eikonal_points = torch.empty(n_eik_points, 3, device='cuda').uniform_(-3.0, 3.0)
                points_all = torch.cat([vtx_pos, eikonal_points], dim=0)
                grad_theta = sdf.gradient(points_all)[:, 0, :]
                eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()

                loss = img_weight * img_loss + eikonal_weight * eikonal_loss + sil_weight * sil_loss
                total_image_loss += img_loss.detach().cpu().item()
                total_sil_loss += sil_loss.detach().cpu().item()
                total_eikonal_loss += eikonal_loss.detach().cpu().item()
                total_loss += loss.detach().cpu().item()

                # The graph through vtx / vtx_bsdf_params is reused by later sensors and passes.
                loss.backward(retain_graph=True)
            del opt_scene
            enoki.cuda_malloc_trim()

        t1 = time.time()
        total_time = t1 - t0

        print(f'[Info] time spent on rendering opt images: {total_render_time:.3f}s (avg time per view: {(total_render_time / batch_size):.3f}s)')
        print(f'[Info] total time spent: {total_time:.3f}s')
        print('[INFO] iter = %d, image_loss = %.4e, eikonal_loss = %.4e, silhouette_loss = %.4e, total_loss = %.4e' % (it, total_image_loss, total_eikonal_loss, total_sil_loss, total_loss))

        # Top row: targets; bottom row: current renders of the visualization sensors.
        current_opt_imgs = render_visualization(vtx_bsdf_params, scene_dir, _vis_scene_indices, _vis_sensor_indices, image_shape, render_spec)
        current_target_imgs = [target_images[sensor_id] for sensor_id in vis_sensor_indices]

        result_image = make_grid(np.stack([torch.cat(current_target_imgs, dim=1).cpu().numpy(), torch.cat(current_opt_imgs, dim=1).detach().cpu().numpy()], axis=0), 1)
        write_png(f'{result_dir}/iter_{it}', result_image)

        sdf_optimizer.step()
        if not disable_nrf:
            nrf_optimizer.step()
        sdf_scheduler.step()
        nrf_scheduler.step()

        if it % checkpoint_iter == 0:
            torch.save({
                'it': it,
                'sdf_model_state_dict': sdf.state_dict(),
                'sdf_optimizer_state_dict': sdf_optimizer.state_dict(),
                'sdf_scheduler_state_dict': sdf_scheduler.state_dict(),
                'nrf_model_state_dict': nrf.state_dict(),
                'nrf_optimizer_state_dict': nrf_optimizer.state_dict(),
                'nrf_scheduler_state_dict': nrf_scheduler.state_dict(),
                'image_loss': total_image_loss,
                'eikonal_loss': total_eikonal_loss,
                'sil_loss': total_sil_loss,
                'total_loss': total_loss,
                'total_render_time': total_render_time,
                'total_time': total_time,
                'mc_time': mc_time
            }, os.path.join(checkpoint_dir, f'checkpoint_{it}.pt'))
            write_obj(vtx.detach().cpu().numpy(), faces.cpu().numpy(), os.path.join(checkpoint_dir, 'sdf.obj'))
            print('[INFO] checkpoint is updated')


def main():
    """Command-line entry point: parse options and run ``optimize``.

    Every option defaults to ``argparse.SUPPRESS``, so only flags actually
    given on the command line are forwarded. Omitted options fall back to the
    defaults declared by ``optimize``, and running with no flags is exactly
    ``optimize()``.
    """
    parser = argparse.ArgumentParser(description='Optimizing neural SDF to match the target image')
    parser.add_argument('--scene_dir', type=str, default=argparse.SUPPRESS,
                        help='directory containing opt_*.xml, target/ and target_sil/')
    parser.add_argument('--max_iter', type=int, default=argparse.SUPPRESS,
                        help='last iteration index (inclusive)')
    parser.add_argument('--resolution', type=int, default=argparse.SUPPRESS,
                        help='marching-cubes grid resolution per axis')
    parser.add_argument('--batch_size', type=int, default=argparse.SUPPRESS,
                        help='number of sensors rendered per iteration')
    parser.add_argument('--sdf_lr', type=float, default=argparse.SUPPRESS,
                        help='learning rate of the SDF network')
    parser.add_argument('--nrf_lr', type=float, default=argparse.SUPPRESS,
                        help='learning rate of the reflectance network')
    parser.add_argument('--checkpoint_iter', type=int, default=argparse.SUPPRESS,
                        help='checkpoint period in iterations')
    parser.add_argument('--from_checkpoint', type=int, default=argparse.SUPPRESS,
                        help='iteration to resume from (negative starts from scratch)')
    parser.add_argument('--diffuse_only', action='store_true', default=argparse.SUPPRESS,
                        help='zero all non-diffuse BSDF channels')
    parser.add_argument('--disable_nrf', action='store_true', default=argparse.SUPPRESS,
                        help='never step the reflectance network optimizer')
    args = parser.parse_args()

    # Run.
    optimize(**vars(args))

    # Done.
    print("Done.")

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    main()