from copy import deepcopy
from typing import List
import gin
from pypsdr.plot import write_obj, get_mesh_error
import psdr_cpu
import dataclasses
from dataclasses import field
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.loss import compute_render_loss, uniform_laplacian
from pypsdr.optimize import Camera
from pypsdr.optimizer import LGDescent,sparse_eye
from pypsdr.renderer import Render
from pypsdr.common import gen_cameras
import igl
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"

from random import randrange, seed

import scipy
import scipy.sparse.linalg
import datetime

# Module-level switches read by TrainRunner.run().
# When True, run() renders the target and initial scenes (plus their
# normalized mesh views) and writes results/tar.exr and results/init.exr;
# when False, results/tar.exr from a previous run is read back instead.
render_target_image = False
# NOTE(review): not referenced anywhere in this script's visible code —
# confirm whether it is still needed.
render_iter_image = True


@gin.configurable
# generate the constructor
@dataclasses.dataclass
class TrainRunner:
    """Re-evaluate the saved iterates of the "correlation" experiment.

    Fields are injected by gin. :meth:`run` reads ``scene_init_file``,
    ``camera_file``, ``options`` and ``integrator`` and overwrites
    ``out_dir`` with a hard-coded path. The remaining fields are not read
    anywhere in this class (presumably shared with the training config).
    """
    scene_init_file: str = "./kitty/ini.xml"
    scene_target_file: str = "./kitty/tar.xml"
    shape_id: int = 0  # shape to be optimized
    niter: int = 500
    camera_file: str = "./kitty/cam_pos.txt"
    # Render options for the iterate renders. Must be set (e.g. via gin)
    # before run(), which assigns ``self.options.seed``.
    options: psdr_cpu.RenderOptions = None
    guiding_options: GuidingOptions = None
    out_dir: str = "./output"          # output directory (overridden in run())
    # directory that contains target images
    target_dir: str = "./kitty/output/kitty/target"
    test_indices: List = field(
        default_factory=lambda: [0, 14, 32, 57])
    # default_factory gives every instance its own native object instead of
    # a single one built at import time and shared by all TrainRunners.
    mala_options: psdr_cpu.MALAOptions = field(
        default_factory=psdr_cpu.MALAOptions)

    lr: float = 1.0
    lmbda: float = 500.0
    batch_size: int = 10
    print_size: int = 10

    # Integrator used for the scene half of each iterate image.
    integrator: psdr_cpu.Integrator = field(default_factory=psdr_cpu.Path2)

    def __post_init__(self):
        """No extra initialization beyond the generated ``__init__``."""
        return

    @staticmethod
    def _load_show_scene(vertices, indices):
        """Write a mesh to the shared show-mesh OBJ and load the show-mesh scene.

        The OBJ file is shared, so every call overwrites it.
        """
        obj_path = '../data/scenes/show_mesh/mesh.obj'
        write_obj(vertices, indices, obj_path)
        return Scene('../data/scenes/show_mesh/scene.xml', {'mesh_file': obj_path})

    @staticmethod
    def _normalize_show_scene(scene, mean, scale):
        """Map the first shape's vertices to ``(v - mean) / scale`` and reconfigure."""
        shape = scene.shapes[0]
        shape.setVertices((np.array(shape.vertices) - mean) / scale)
        scene.configure()

    def _render_mesh_view(self, scene, options, mask_integrator):
        """Render a show-mesh scene with ``self.mesh_integrator`` over white.

        The mask render acts as per-pixel alpha, so pixels where the mask is
        0 become 1.0 (white). Returns an ``(height, width, 3)`` array.
        """
        image = self.mesh_integrator.renderC(scene, options).reshape(
            self.height, self.width, 3)
        mask = mask_integrator.renderC(scene, options).reshape(
            self.height, self.width, 3)
        return image * mask + np.ones_like(image) * (1 - mask)

    def run(self):
        """Score saved iterates against the target mesh and save comparison images.

        For ``k`` in 0, 10, ..., 600, this method:

        - loads ``<out_dir>/iter/iter_<k>.obj``;
        - computes its mesh error against the target shape;
        - rewrites ``<out_dir>/results/mesh_loss.log`` with all errors so far.

        For ``k`` in {200, 400, 600} it also writes
        ``<out_dir>/results/iter/iter_<k>.exr``: the rendered scene next to
        a normalized view of the mesh. When the module flag
        ``render_target_image`` is True, the target and initial scenes are
        rendered the same way into ``results/tar.exr`` / ``results/init.exr``.
        """
        # NOTE(review): overrides any configured out_dir. The iterate OBJs
        # are read from this directory, so the override is kept as-is.
        self.out_dir = "./output/correlation0"
        out_dir = self.out_dir
        results_dir = os.path.join(out_dir, "results")
        mkdir(out_dir)
        mkdir(os.path.join(out_dir, "iter"))
        mkdir(results_dir)
        mkdir(os.path.join(results_dir, "iter"))

        self.scene_init = Scene(self.scene_init_file)
        self.width = self.scene_init.camera.width
        self.height = self.scene_init.camera.height
        self.fov = self.scene_init.camera.fov
        camera_pos = np.loadtxt(self.camera_file)
        camera_info = Camera(target=gin.REQUIRED)
        self.cameras = gen_cameras(positions=camera_pos,
                                   target=camera_info.target,
                                   up=camera_info.up,
                                   fov=camera_info.fov,
                                   resolution=[self.width, self.height],
                                   type=1)
        self.sceneAD = psdr_cpu.SceneAD(self.scene_init)
        self.mesh_integrator = psdr_cpu.Path2()
        psdr_cpu.set_verbose(True)

        tar_integrator = psdr_cpu.Path2()
        show_mesh_options = RenderOptions(gin.REQUIRED)
        show_mesh_options.max_bounces = 3
        show_mesh_options.spp = 1024
        with gin.config_scope('tar'):
            tar_options = RenderOptions(gin.REQUIRED)

        render_camera = self.cameras[1]
        scene_tar = Scene('../data/scenes/correlation/tar.xml')
        scene_tar.camera = render_camera
        mesh_tar = scene_tar.shapes[0]
        tar_vertices = np.array(mesh_tar.vertices_world)
        tar_faces = np.array(mesh_tar.indices)

        mask_integrator = psdr_cpu.Mask()
        scene_show_target_mesh = self._load_show_scene(
            mesh_tar.vertices_world, mesh_tar.indices)
        # Normalization shared by every mesh view: centre on the target's
        # mean vertex and scale so its largest |coordinate| becomes 2/3.
        show_vertices_tar = np.array(scene_show_target_mesh.shapes[0].vertices)
        mean = np.mean(show_vertices_tar, axis=0)
        scale = np.max(np.abs(show_vertices_tar - mean)) * 1.5
        print(mean, scale)

        scene_init = Scene('../data/scenes/correlation/init.xml')
        scene_init.camera = render_camera
        mesh_init = scene_init.shapes[0]
        scene_show_init_mesh = self._load_show_scene(
            mesh_init.vertices_world, mesh_init.indices)

        self._normalize_show_scene(scene_show_target_mesh, mean, scale)
        self._normalize_show_scene(scene_show_init_mesh, mean, scale)

        if render_target_image:
            for scene, show_scene, filename in (
                    (scene_tar, scene_show_target_mesh, "tar.exr"),
                    (scene_init, scene_show_init_mesh, "init.exr")):
                scene.configure()
                image = tar_integrator.renderC(scene, tar_options).reshape(
                    self.height, self.width, 3)
                mesh_image = self._render_mesh_view(
                    show_scene, show_mesh_options, mask_integrator)
                imwrite(np.concatenate([image, mesh_image], axis=1),
                        os.path.join(results_dir, filename))
        else:
            # tar.exr stores [scene | mesh view] side by side; keep the scene
            # half (its width is the camera width, not a fixed 256).
            # NOTE(review): image_tar is not used below because the image
            # RMSE is disabled. The read still fails fast if tar.exr is missing.
            image_tar = imread(os.path.join(results_dir, "tar.exr"))[:, :self.width, :]

        seed(1)
        mesh_error = []

        def render_and_loss(obj_filename, render_seed, render=True):
            """Load one saved iterate and return ``(image, mesh_image, mesh_loss)``.

            The mesh error against the target is always computed. The two
            images, and the visualization scene they need, are produced only
            when ``render`` is True; otherwise both images are ``None``.
            """
            iter_v, _, _, iter_f, _, _ = igl.read_obj(obj_filename)
            # Write the iterate, then reload the init scene to obtain
            # world-space vertices. Presumably init.xml references this
            # mesh.obj — TODO confirm.
            write_obj(iter_v, iter_f, "../data/scenes/correlation/mesh.obj")
            scene_iter = Scene('../data/scenes/correlation/init.xml')
            scene_iter.camera = render_camera
            mesh_iter = scene_iter.shapes[0]
            world_v = np.array(mesh_iter.vertices_world)
            world_f = np.array(mesh_iter.indices)
            mesh_loss = get_mesh_error(world_v, world_f, tar_vertices, tar_faces)
            if not render:
                return None, None, mesh_loss

            scene_show_iter_mesh = self._load_show_scene(world_v, world_f)
            self._normalize_show_scene(scene_show_iter_mesh, mean, scale)
            self.options.seed = render_seed
            scene_iter.configure()
            image_iter = self.integrator.renderC(scene_iter, self.options).reshape(
                self.height, self.width, 3)
            # Same integrator as the target/init mesh views, so all saved
            # comparison images are rendered consistently.
            mesh_image_iter = self._render_mesh_view(
                scene_show_iter_mesh, show_mesh_options, mask_integrator)
            return image_iter, mesh_image_iter, mesh_loss

        global_iter = 0  # offset added to the saved image index
        for step in range(0, 601, 10):
            print(step, global_iter)
            savefig = step in (200, 400, 600)
            image_iter, mesh_image_iter, mesh_loss = render_and_loss(
                os.path.join(out_dir, "iter", "iter_%d.obj" % step), step, savefig)
            if savefig:
                imwrite(np.concatenate([image_iter, mesh_image_iter], axis=1),
                        os.path.join(results_dir, "iter",
                                     "iter_%d.exr" % (global_iter + step)))
            mesh_error.append(mesh_loss)
            # Rewritten every step so an interrupted run still leaves a log.
            np.savetxt(os.path.join(results_dir, "mesh_loss.log"), mesh_error)
            


import argparse

if __name__ == "__main__":
    # Config used when no path is supplied on the command line.
    default_config = './correlation_configs/correlation0.conf'
    cli = argparse.ArgumentParser(
        description='Script for generating validation results')
    cli.add_argument('config_file', metavar='config_file', type=str,
                     nargs='?', default=default_config, help='config file')
    # parse_known_args tolerates (and ignores) unrecognized flags.
    parsed_args, _extra_args = cli.parse_known_args()
    # Load gin bindings so gin-configurable constructors (e.g. TrainRunner)
    # receive their arguments; bindings for unknown configurables are skipped.
    gin.parse_config_file(parsed_args.config_file, skip_unknown=True)
    TrainRunner().run()