import numpy as np
import torch, argparse, os, sys, re, time, shutil
import enoki
from enoki.cuda_autodiff import Float32 as FloatD, Vector3f as Vector3fD
import psdr_cuda
import subprocess
from utils import *
import neural_sdf
from meshSDF import *

def render_textured_obj(out_dir, intensity):
    """Render the textured scene described by ``textured_template.xml``.

    The template is read from the current working directory and formatted
    with *out_dir*. The scene is rendered with a collocated integrator over
    16 passes of 16 spp each. The passes are averaged and the result is
    written via ``write_exr`` under the name ``<out_dir>/textured``.

    Args:
        out_dir: Directory substituted into the scene template; the output
            image is written here as well.
        intensity: Light intensity passed to ``psdr_cuda.CollocatedIntegrator``.
    """
    with open('textured_template.xml', 'r') as template_file:
        scene_xml = template_file.read().format(out_dir)
    scene = psdr_cuda.Scene()
    scene.load_string(scene_xml, False)

    num_passes = 16
    scene.opts.spp = 16        # samples per pixel in each pass
    scene.opts.log_level = 0
    scene.configure()

    renderer = psdr_cuda.CollocatedIntegrator(intensity)
    # Accumulate the passes, then divide once to get the mean image.
    accum = renderer.renderC(scene).numpy()
    for _ in range(1, num_passes):
        accum += renderer.renderC(scene).numpy()
    height, width = scene.opts.height, scene.opts.width
    mean_img = (accum / float(num_passes)).reshape((height, width, 3))
    write_exr(os.path.join(out_dir, 'textured'), mean_img)

def get_mesh(checkpoint, resolution, out_dir):
    """Extract a triangle mesh from the neural SDF stored in *checkpoint*.

    The SDF network is rebuilt with the fixed architecture below, and its
    weights are loaded from ``checkpoint['sdf_model_state_dict']``. It is
    polygonized with ``IsoSurface`` inside the box [-1, 1]^3. The grid size is
    given by the third ``BoundingBox`` argument, which is *resolution* along
    every axis. The resulting mesh is written to ``<out_dir>/sdf.obj``.

    Args:
        checkpoint: Path to a file loadable with ``torch.load``.
        resolution: Grid resolution per axis handed to ``BoundingBox``.
        out_dir: Output directory (created if missing).
    """
    os.makedirs(out_dir, exist_ok=True)

    state = torch.load(checkpoint)

    network = neural_sdf.ImplicitNetwork(
        feature_vector_size=256,
        d_in=3,
        d_out=1,
        dims=[512] * 8,
        geometric_init=True,
        bias=0.6,
        skip_in=[4],
        weight_norm=True,
        multires=6,
    ).cuda()
    network.load_state_dict(state['sdf_model_state_dict'])

    box_min = torch.tensor([-1.0] * 3, device='cuda')
    box_max = torch.tensor([1.0] * 3, device='cuda')
    grid_res = torch.tensor([resolution] * 3, device='cuda')
    bbox = BoundingBox(box_min, box_max, grid_res)

    vertices, triangles = IsoSurface(network, bbox)()
    mesh_path = os.path.join(out_dir, 'sdf.obj')
    write_obj(vertices.detach().cpu().numpy(), triangles.cpu().numpy(), mesh_path)

def gen_obj_uv(obj, out_dir, num_cones=8):
    """Compute a UV parameterization of *obj* with the ``bff-command-line`` tool.

    The parameterized mesh is written to ``<out_dir>/sdf_textured.obj``. UVs
    are normalized (``--normalizeUVs``).

    Args:
        obj: Path to the input mesh.
        out_dir: Output directory (created if missing).
        num_cones: Number of cone singularities, passed to BFF as
            ``--nCones``. It was previously ignored; 8 was always used.

    Raises:
        subprocess.CalledProcessError: If the tool exits with a non-zero
            status. Without this, a failure would only surface later as a
            missing ``sdf_textured.obj``.
        FileNotFoundError: If ``bff-command-line`` is not on ``PATH``.
    """
    os.makedirs(out_dir, exist_ok=True)
    obj_uv = os.path.join(out_dir, 'sdf_textured.obj')
    subprocess.run(
        ['bff-command-line', obj, obj_uv, f'--nCones={num_cones}', '--normalizeUVs'],
        check=True,
    )

def get_textures(checkpoint, out_dir, resolution, multisample_factor):
    """Bake the neural BSDF parameters of *checkpoint* into texture images.

    Reads the UV-parameterized mesh ``<out_dir>/sdf_textured.obj`` (the file
    ``gen_obj_uv`` writes) and then:

    1. writes ``<out_dir>/sdf_textured_flatten.obj``, a copy of the mesh laid
       out in UV space: UV vertex ``i`` is placed at ``(u_i, v_i, 0)`` and the
       faces are the UV face indices;
    2. evaluates the rendering network at the 3D position of every UV vertex
       and rescales its 7 outputs into per-vertex diffuse (3), specular (3)
       and roughness (1) values;
    3. calls ``psdr_cuda.NRF.get_textures`` on the flattened scene and writes
       the ``diffuse``, ``specular`` and ``roughness`` EXR textures
       (``resolution`` x ``resolution``) into *out_dir*.

    Args:
        checkpoint: Path to a torch checkpoint containing
            ``'nrf_model_state_dict'``.
        out_dir: Directory holding ``sdf_textured.obj``; outputs go here too.
        resolution: Width and height of each texture, in texels.
        multisample_factor: Forwarded to ``psdr_cuda.NRF.get_textures``
            (presumably a per-texel supersampling factor; TODO confirm).

    Raises:
        FileNotFoundError: If ``<out_dir>/sdf_textured.obj`` does not exist.
    """
    obj_uv = os.path.join(out_dir, 'sdf_textured.obj')
    obj_uv_flatten = os.path.join(out_dir, 'sdf_textured_flatten.obj')

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(obj_uv):
        raise FileNotFoundError(f"Can't find UV-parameterized mesh: {obj_uv}")

    checkpoint = torch.load(checkpoint)

    # --- Create a flattened version of the object -------------------------
    mesh = psdr_cuda.Mesh()
    mesh.load(obj_uv)
    mesh.configure()

    vertex = mesh.vertex_positions.numpy()
    vertex_uv = mesh.vertex_uv.numpy()
    face_indices = mesh.face_indices.numpy()
    face_uv_indices = mesh.face_uv_indices.numpy()
    vertex_positions_flatten = np.c_[vertex_uv, np.zeros(vertex_uv.shape[0])]

    # 3D position of every UV vertex, gathered through matching face corners.
    # This is a vectorized form of a per-corner loop. Each UV vertex is
    # expected to map to a single mesh vertex, so repeated writes agree.
    # UV vertices referenced by no face stay at the origin.
    vertex_positions_old = np.zeros_like(vertex_positions_flatten)
    vertex_positions_old[face_uv_indices.reshape(-1)] = vertex[face_indices.reshape(-1)]
    write_obj(vertex_positions_flatten, face_uv_indices, obj_uv_flatten)

    # --- Evaluate per-vertex BSDF parameters -----------------------------
    nrf = neural_sdf.RenderingNetwork(d_in=3,
                                      d_out=7,
                                      dims=[512, 512, 512, 512, 512, 512, 512, 512],
                                      weight_norm=True,
                                      multires_point=10)
    nrf = nrf.cuda()
    nrf.load_state_dict(checkpoint['nrf_model_state_dict'])
    # Inference only: skip building an autograd graph over all vertices.
    with torch.no_grad():
        vtx_bsdf_params = nrf(torch.from_numpy(vertex_positions_old).float().cuda())
    # Affine map from (assumed) [-1, 1] network outputs to [min, max] per
    # channel. Channels 0-2 are diffuse, 3-5 specular, 6 roughness.
    bsdf_param_max = torch.tensor([1, 1, 1, 1, 1, 1, 0.5]).float().cuda()
    bsdf_param_min = torch.tensor([0, 0, 0, 0, 0, 0, 0.1]).float().cuda()
    vtx_bsdf_params = 0.5 * (vtx_bsdf_params + 1.0) * (bsdf_param_max - bsdf_param_min) + bsdf_param_min

    # --- Bake textures on the flattened mesh ------------------------------
    with open('flatten_template.xml', 'r') as f:
        flatten_xml = f.read().format(obj_uv_flatten)

    scene = psdr_cuda.Scene()
    scene.load_string(flatten_xml, False)
    scene.opts.log_level = 0
    scene.param_map["Mesh[0]"].vertex_diffuse = Vector3fD(vtx_bsdf_params[:, 0:3])
    scene.param_map["Mesh[0]"].vertex_specular = Vector3fD(vtx_bsdf_params[:, 3:6])
    scene.param_map["Mesh[0]"].vertex_roughness = FloatD(vtx_bsdf_params[:, 6])

    scene.configure()

    # NOTE(review): these are initialized from 1-D tensors of resolution**2
    # zeros. The (res, res, 3) reshapes below rely on NRF.get_textures
    # producing res*res values per channel; confirm.
    diffuse = Vector3fD(torch.zeros(resolution * resolution).cuda())
    specular = Vector3fD(torch.zeros(resolution * resolution).cuda())
    roughness = FloatD(torch.zeros(resolution * resolution).cuda())

    psdr_cuda.NRF.get_textures(scene, resolution, multisample_factor, diffuse, specular, roughness)

    # --- Write textures ---------------------------------------------------
    diffuse_texture = diffuse.numpy().reshape((resolution, resolution, 3))
    specular_texture = specular.numpy().reshape((resolution, resolution, 3))
    roughness_texture = roughness.numpy().reshape((resolution, resolution, 1))
    write_exr(os.path.join(out_dir, 'diffuse'), diffuse_texture)
    write_exr(os.path.join(out_dir, 'specular'), specular_texture)
    write_exr(os.path.join(out_dir, 'roughness'), roughness_texture)

    # Drop every enoki-backed object (including the scene) before trimming so
    # the freed GPU memory can actually be returned.
    del mesh, scene, diffuse, specular, roughness
    enoki.cuda_malloc_trim()

def get_textured_obj(checkpoint, out_dir, mesh_resolution=96, texture_resolution=4096,
                     multisample_factor=4, num_cones=8):
    """Run the full checkpoint -> textured-mesh pipeline into *out_dir*.

    Runs three steps in order:

    1. ``get_mesh`` extracts ``sdf.obj``.
    2. ``gen_obj_uv`` UV-unwraps it into ``sdf_textured.obj``.
    3. ``get_textures`` bakes the BSDF textures.

    The defaults reproduce the previously hard-coded settings.

    Args:
        checkpoint: Path to the training checkpoint.
        out_dir: Output directory for all generated files.
        mesh_resolution: Per-axis grid resolution for mesh extraction.
        texture_resolution: Width/height of the baked textures, in texels.
        multisample_factor: Forwarded to ``get_textures``.
        num_cones: Forwarded to ``gen_obj_uv``.
    """
    get_mesh(checkpoint, mesh_resolution, out_dir)
    gen_obj_uv(os.path.join(out_dir, 'sdf.obj'), out_dir, num_cones)
    get_textures(checkpoint, out_dir, texture_resolution, multisample_factor)

def render_target(scene_dir, integrator_name, intensity=0):
    """Render every ``target_<k>.xml`` scene in *scene_dir* into target images.

    Each scene is rendered in 16 passes of 32 spp, and the passes are
    averaged. For every sensor, the shaded image and a silhouette image are
    written as EXR and PNG under ``<scene_dir>/target/{exr,png}`` and
    ``<scene_dir>/target_sil/{exr,png}``, named ``sensor_<id>``.

    Sensor ids are global: scenes are processed in increasing ``k``, and the
    sensors of each scene continue numbering after the previous scene's.

    Args:
        scene_dir: Directory containing the ``target_<k>.xml`` scene files.
        integrator_name: ``'direct'`` or ``'collocated'``.
        intensity: Light intensity for the collocated integrator. It is
            ignored by ``'direct'``.

    Raises:
        ValueError: If *integrator_name* is not supported. This is checked
            before any rendering.
    """
    if integrator_name not in ('direct', 'collocated'):
        raise ValueError(f"Doesn't support this integrator: {integrator_name!r}")

    # Only files named exactly target_<k>.xml are render targets. A looser
    # substring match would also count e.g. target.xml, and the loop would
    # then try to load a non-existent target_<n>.xml.
    pattern = re.compile(r'target_(\d+)\.xml')
    scene_files = sorted(
        (fn for fn in os.listdir(scene_dir) if pattern.fullmatch(fn)),
        key=lambda fn: int(pattern.fullmatch(fn).group(1)),
    )
    print('About to render the following target scene files: {}'.format(', '.join(scene_files)))

    target_exr_dir = os.path.join(scene_dir, 'target', 'exr')
    target_png_dir = os.path.join(scene_dir, 'target', 'png')
    target_sil_exr_dir = os.path.join(scene_dir, 'target_sil', 'exr')
    target_sil_png_dir = os.path.join(scene_dir, 'target_sil', 'png')
    for out_dir in (target_exr_dir, target_png_dir, target_sil_exr_dir, target_sil_png_dir):
        os.makedirs(out_dir, exist_ok=True)

    npass = 16               # number of passes
    spp = 32                 # spp per pass
    sensor_id_offset = 0
    for scene_name in scene_files:
        scene_fn = os.path.join(scene_dir, scene_name)
        scene = psdr_cuda.Scene()
        print(f'Rendering {scene_fn}...')
        scene.load_file(scene_fn, False)
        scene.opts.spp = spp
        scene.opts.log_level = 0
        scene.configure()
        num_sensors = scene.num_sensors

        if integrator_name == 'direct':
            integrator = psdr_cuda.DirectIntegrator(1, 1, 0)
        else:
            integrator = psdr_cuda.CollocatedIntegrator(intensity)
        silhouette_integrator = psdr_cuda.FieldExtractionIntegrator("silhouette")

        # Render and accumulate every pass. Wall-clock time is used because
        # the work happens on the GPU, which CPU process time does not measure.
        imgs = [None] * num_sensors
        sil_imgs = [None] * num_sensors
        t0 = t1 = time.perf_counter()
        for pass_idx in range(npass):
            for sensor_id in range(num_sensors):
                img = integrator.renderC(scene, sensor_id)
                sil_img = silhouette_integrator.renderC(scene, sensor_id)
                if pass_idx == 0:
                    imgs[sensor_id] = img.numpy()
                    sil_imgs[sensor_id] = sil_img.numpy()
                else:
                    imgs[sensor_id] += img.numpy()
                    sil_imgs[sensor_id] += sil_img.numpy()
            t2 = time.perf_counter()
            if t2 - t1 > 0.2:
                print("(%d/%d) done in %.2f seconds." % (pass_idx + 1, npass, t2 - t0), end="\r")
                t1 = t2
        t2 = time.perf_counter()
        print("(%d/%d) Total orig. rendering time: %.2f seconds." % (npass, npass, t2 - t0))

        # Average the passes and write outputs under global sensor ids.
        img_shape = (scene.opts.height, scene.opts.width, 3)
        for local_id in range(num_sensors):
            name = f'sensor_{local_id + sensor_id_offset}'
            img = (imgs[local_id] / float(npass)).reshape(img_shape)
            sil_img = (sil_imgs[local_id] / float(npass)).reshape(img_shape)

            write_exr(os.path.join(target_exr_dir, name), img)
            write_png(os.path.join(target_png_dir, name), img)
            write_exr(os.path.join(target_sil_exr_dir, name), sil_img)
            write_png(os.path.join(target_sil_png_dir, name), sil_img)
        sensor_id_offset += num_sensors

        # Garbage collection: release enoki-backed objects before trimming.
        del scene, integrator, silhouette_integrator
        enoki.cuda_malloc_trim()

def gen_idr_data(scene_dir):
    """Export rendered targets and cameras in the IDR dataset layout.

    Rebuilds ``<scene_dir>/idr_data`` from scratch with:

    * ``image/``: a copy of ``<scene_dir>/target/png``;
    * ``mask/``: a copy of ``<scene_dir>/target_sil/png``;
    * ``cameras.npz``: ``world_mat_<i>`` (projection padded to 4x4) and
      ``scale_mat_<i>`` (identity) for every sensor of
      ``<scene_dir>/target.xml``.

    In both copied folders, ``sensor_<i>.png`` is renamed to ``<i:06>.png``.

    Raises:
        FileNotFoundError: If the target or silhouette PNG directory is
            missing. This is checked before any existing export is deleted.
    """
    idr_data_dir = os.path.join(scene_dir, 'idr_data')
    idr_image_dir = os.path.join(idr_data_dir, 'image')
    idr_mask_dir = os.path.join(idr_data_dir, 'mask')

    target_dir = os.path.join(scene_dir, 'target', 'png')
    target_sil_dir = os.path.join(scene_dir, 'target_sil', 'png')
    # Validate inputs first (and without `assert`, which `python -O` strips)
    # so a failed run does not wipe a previous export.
    if not os.path.exists(target_dir):
        raise FileNotFoundError("Can't find target images.")
    if not os.path.exists(target_sil_dir):
        raise FileNotFoundError("Can't find target sillhouette images.")

    # Start from a clean export; copytree recreates idr_data_dir as a parent.
    if os.path.exists(idr_data_dir):
        shutil.rmtree(idr_data_dir)
    shutil.copytree(target_dir, idr_image_dir)
    shutil.copytree(target_sil_dir, idr_mask_dir)

    # NOTE(review): cameras are read from target.xml, whereas render_target
    # renders target_<k>.xml with globally offset sensor ids; this assumes the
    # sensor order matches.
    scene = psdr_cuda.Scene()
    scene.load_file(os.path.join(scene_dir, 'target.xml'), False)
    scene.opts.log_level = 0
    scene.configure()

    # Scales sample-space x/y by the image size to get pixel coordinates.
    sample_to_image = np.array([[scene.opts.width, 0, 0],
                                [0, scene.opts.height, 0],
                                [0, 0, 1]])

    world_mats = {}
    scale_mats = {}
    for sensor_id in range(scene.num_sensors):
        camera = scene.param_map['Sensor[{0}]'.format(sensor_id)]
        world_to_sample = camera.world_to_sample.numpy()[0]
        # NOTE(review): rows 0-2 of world_to_sample form the 3x4 projection,
        # so row 2 acts as the homogeneous divisor. If psdr_cuda's
        # world_to_sample is a standard 4x4 perspective transform (divisor
        # in row 3), this P is off; confirm against the sensor implementation.
        P = sample_to_image @ world_to_sample[:3, :]
        world_mats[f'world_mat_{sensor_id}'] = np.r_[P, np.array([[0, 0, 0, 1]])]
        scale_mats[f'scale_mat_{sensor_id}'] = np.eye(4)

        old_image_name = f'sensor_{sensor_id}.png'
        new_image_name = f'{sensor_id:06}.png'
        os.rename(os.path.join(idr_image_dir, old_image_name), os.path.join(idr_image_dir, new_image_name))
        os.rename(os.path.join(idr_mask_dir, old_image_name), os.path.join(idr_mask_dir, new_image_name))
    # World matrices first, then scale matrices (same key order as before).
    np.savez(os.path.join(idr_data_dir, 'cameras'), **world_mats, **scale_mats)



if __name__ == '__main__':
    # Command-line entry point. Each key is (command, len(sys.argv)); the
    # argv length distinguishes overloads such as gen_obj_uv's optional cone
    # count. Each handler receives the full sys.argv list.
    commands = {
        ('render_textured_obj', 4): lambda argv: render_textured_obj(argv[2], float(argv[3])),
        ('get_textured_obj', 4): lambda argv: get_textured_obj(argv[2], argv[3]),
        ('get_mesh', 5): lambda argv: get_mesh(argv[2], int(argv[3]), argv[4]),
        ('gen_obj_uv', 4): lambda argv: gen_obj_uv(argv[2], argv[3]),
        ('gen_obj_uv', 5): lambda argv: gen_obj_uv(argv[2], argv[3], int(argv[4])),
        ('get_textures', 4): lambda argv: get_textures(argv[2], argv[3], 1024, 4),
        ('render_target', 4): lambda argv: render_target(argv[2], 'collocated', float(argv[3])),
        ('render_target_env', 3): lambda argv: render_target(argv[2], 'direct'),
        ('gen_idr_data', 3): lambda argv: gen_idr_data(argv[2]),
    }

    # Guard against a missing command instead of raising IndexError.
    command = sys.argv[1] if len(sys.argv) > 1 else None
    handler = commands.get((command, len(sys.argv)))
    if handler is None:
        print('Command not found.', file=sys.stderr)
        print('Available commands: ' + ', '.join(sorted({name for name, _ in commands})),
              file=sys.stderr)
        sys.exit(1)
    handler(sys.argv)
