# Copyright @yucwang 2022

import os
# Register <CUDA_PATH>/bin as a DLL search directory; presumably needed so the
# TensorRay import below can locate the CUDA libraries — confirm.
# NOTE(review): os.add_dll_directory is only available on Windows, and
# os.environ['CUDA_PATH'] raises KeyError when the variable is not set.
os.add_dll_directory(os.path.join(os.environ['CUDA_PATH'], 'bin'))

import sys
# Add "../../" to the module search path. The path is relative, so it is
# resolved against the current working directory, not this file's location.
sys.path.insert(1, "../../")

import argparse
import dataclasses
from typing import List

import TensorRay as TR
import numpy as np
from pyTensorRay.multi_view import gen_camera, gen_cameras, gen_camera_positions, render_save_multi_views, gen_camera_positions_stratified
from pyTensorRay.utils import image_tensor_to_torch, save_torch_image, update_render_batch_options, load_torch_image
from pyTensorRay.fwd import Camera, RenderOptions, Scene
from pyTensorRay.common import mkdir
import gin
from random import seed
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


@gin.configurable
@dataclasses.dataclass
class PreprocessConfig:
    """Settings for rendering multi-view target images of a scene.

    The class is gin-configurable, so field values can be supplied by the
    gin config file parsed in the script entry point. `run` and
    `run_mesh_vis` read the fields documented below.
    """
    # Render options forwarded to the integrator / multi-view renderer (required).
    render_options: TR.RISRenderOptions
    # Path handed to `Scene(...)` (required).
    scene_file: str
    # Text file loaded with `np.loadtxt`, one camera per row; when left empty,
    # `run` generates the camera positions itself.
    camera_info_file: str = ""
    # NOTE(review): only referenced from commented-out code in this file; `run`
    # currently passes a hard-coded batch to the position generator.
    batch_size: int = 1
    # Radius passed to the camera position generator in `run`.
    radius: float = 1.0
    # Directory that receives `cam_pos.txt` and the rendered images.
    out_dir: str = "./"
    # NOTE(review): only referenced from commented-out code in this file.
    on_hemisphere: bool = False
    # NOTE(review): this default is evaluated once, when the class body runs
    # (i.e. at import time, before `TR.env_create()` is called in the entry
    # point), and the same instance is shared by every config that does not
    # override it — confirm this is intended.
    integrator: TR.Integrator = TR.Path2()
    # Reference camera; `run` reads its `target`, `up` and `fov` when it
    # generates cameras. The default instance is likewise created once and shared.
    camera_info: Camera = Camera()

def run(preprocess_config):
    """Render the configured scene from a set of cameras and save the results.

    Cameras come from one of two sources:

    * ``preprocess_config.camera_info_file`` is empty: positions are generated
      with ``gen_camera_positions_stratified`` and offset by the reference
      camera's target; target, up and fov are taken from
      ``preprocess_config.camera_info``.
    * otherwise: the file is read with ``np.loadtxt``. Each row describes one
      camera as ``position (cols 0-2), target (3-5), up (6-8), fov (9)``.

    The camera table is written to ``<out_dir>/cam_pos.txt`` in that same
    10-column layout, and the renderings are saved under ``<out_dir>/target``.

    Args:
        preprocess_config: a ``PreprocessConfig`` (or any object exposing the
            same attributes).
    """
    scene = Scene(preprocess_config.scene_file)
    scene.configure()
    width = scene.get_width(0)
    height = scene.get_height(0)
    options = preprocess_config.render_options
    out_dir = preprocess_config.out_dir
    print(width)

    if preprocess_config.camera_info_file == "":
        reference = preprocess_config.camera_info
        # NOTE(review): batch and view_count_sqrt are hard-coded here;
        # preprocess_config.batch_size is not consulted.
        camera_positions = gen_camera_positions_stratified(
            batch=8,
            view_count_sqrt=2,
            radius=preprocess_config.radius
        )
        # Shift every generated position in place so it is expressed relative
        # to the look-at target.
        for position in camera_positions:
            position += reference.target

        cameras = gen_cameras(positions=camera_positions,
                              target=reference.target,
                              up=reference.up,
                              fov=reference.fov,
                              resolution=[width, height])
        cameras_info = np.zeros(shape=(len(cameras), 10))
        cameras_info[:, 0:3] = camera_positions
        cameras_info[:, 3:6] = reference.target
        cameras_info[:, 6:9] = reference.up
        cameras_info[:, 9] = reference.fov
    else:
        # ndmin=2 keeps the table two-dimensional even when the file holds a
        # single camera; without it loadtxt returns a 1-D array and the row
        # indexing below fails.
        cameras_info = np.loadtxt(preprocess_config.camera_info_file, ndmin=2)
        cameras = [
            gen_camera(row[0:3], row[3:6], row[6:9], row[9], [width, height])
            for row in cameras_info
        ]

    target_dir = os.path.join(out_dir, "target")
    mkdir(out_dir)
    np.savetxt(os.path.join(out_dir, "cam_pos.txt"), cameras_info)
    mkdir(target_dir)
    render_save_multi_views(scene, cameras, options,
                            preprocess_config.integrator, target_dir)

def img_remap(img, path):
    """Save the first channel of ``img`` as a colour-mapped picture.

    The channel is drawn with a 256-entry colormap built from the table in
    ``./CubicL.txt`` (read relative to the current working directory), with
    the colour range fixed to [-0.9, 0.9], and written to ``path`` without
    axes or padding. The channel maximum is printed as a diagnostic.

    Args:
        img: image indexable as ``img[:, :, 0]``.
        path: output file name passed to ``Figure.savefig``.
    """
    channel = img[:, :, 0]
    print(channel.max())
    color_map = LinearSegmentedColormap.from_list("cubicL", np.loadtxt("./CubicL.txt"), N=256)
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(channel, cmap=color_map, interpolation='bilinear', vmin=-0.9, vmax=0.9)
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here to avoid accumulating figures across repeated calls.
        plt.close(fig)

def run_mesh_vis(preprocess_config, data_dir, n_iter=256, bsdf_id=2, camera_id=1, stride=8):
    """Re-render the scene with saved per-iteration albedo textures.

    For every ``i`` in ``range(0, n_iter, stride)`` the texture
    ``<data_dir>/iter/albedo_iter_{i+1}.exr`` is reloaded into
    ``scene.bsdfs[bsdf_id]``, the scene is rendered with the configured
    integrator from the camera in row ``camera_id`` of the camera info file,
    and the image is saved as ``<out_dir>/iter_{i}.exr``.

    Args:
        preprocess_config: a ``PreprocessConfig``; ``camera_info_file`` must
            name a camera table (one row per camera: position, target, up, fov).
        data_dir: directory containing the ``iter/albedo_iter_*.exr`` files.
        n_iter: exclusive upper bound of the iteration indices visited.
        bsdf_id: index into ``scene.bsdfs`` of the BSDF whose texture is swapped.
        camera_id: row of the camera table to render from.
        stride: step between visited iteration indices.
    """
    scene = Scene(preprocess_config.scene_file)
    height = scene.get_height(0)
    width = scene.get_width(0)
    # ndmin=2 keeps row indexing valid when the table has a single camera.
    cameras_info = np.loadtxt(preprocess_config.camera_info_file, ndmin=2)
    cam_row = cameras_info[camera_id]
    camera_info = gen_camera(cam_row[0:3], cam_row[3:6], cam_row[6:9], cam_row[9], [width, height])
    scene.cameras[0].update(camera_info)
    scene.configure()

    integrator = preprocess_config.integrator
    render_options = preprocess_config.render_options
    out_dir = preprocess_config.out_dir
    # Create the output directory up front, as `run` does, so saving cannot
    # fail on a missing folder.
    mkdir(out_dir)

    for i in range(0, n_iter, stride):
        # Input files are numbered from 1, output files from 0.
        bsdf_file = os.path.join(data_dir, "iter/albedo_iter_{}.exr".format(i + 1))
        print(bsdf_file)
        scene.bsdfs[bsdf_id].reload(bsdf_file)
        img = integrator.renderC(scene, render_options)
        img = image_tensor_to_torch(img, height, width)
        integrator.step()
        # `img_remap` can be used here instead to write a colour-mapped PNG.
        save_torch_image(os.path.join(out_dir, 'iter_{}.exr'.format(i)), img)
        print("done: {}.".format(i))


if __name__ == "__main__":
    # Seed Python's `random` module so repeated runs use the same sequence.
    seed(1)
    default_config = './sphere_to_cube.conf'
    parser = argparse.ArgumentParser(
        description='Script for generating validation results')
    parser.add_argument('config_file', metavar='config_file',
                        type=str, nargs='?', default=default_config, help='config file')
    # Unrecognised command-line arguments are tolerated and ignored.
    args, _unknown = parser.parse_known_args()
    # Dependency injection: Arguments are injected into the function from the gin config file.
    gin.parse_config_file(args.config_file, skip_unknown=True)
    TR.env_create()
    try:
        preprocess_config = PreprocessConfig()
        # Alternative entry point for per-iteration texture visualisation:
        # run_mesh_vis(preprocess_config, "./output/jupiter/ris_3/")
        run(preprocess_config)
    finally:
        # Release the TensorRay environment even when rendering fails.
        TR.env_release()
