import argparse
import dataclasses
import logging
from typing import Dict, List

import numpy as np
from pypsdr.common import gen_cameras, gen_random_rotations
from pypsdr.optimize import Camera, gen_camera_positions, gen_camera_positions1
import gin
from pypsdr.renderer import Model
from pypsdr.utils.io import imwrite, mkdir
from pypsdr.utils.timer import Timer
from pypsdr.validate import RenderOptions, Scene
import psdr_cpu
import os
from pypsdr.render import render_save_multi_views
from random import seed
# Root logger: INFO and above, printed as "LEVEL:message".
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')


@gin.configurable
# generate the constructor
@dataclasses.dataclass
class PreprocessRunner:
    """Render target images of a scene, one per rotation of a single shape.

    Fields are normally injected by gin from the config file.

    Attributes:
        scene_file: path of the scene loaded with ``Scene``.
        out_dir: output directory; images go to ``<out_dir>/target``.
        integrator: integrator whose ``renderC`` produces each image.
        pose_setting: pose configuration dict. Keys read here:
            ``id`` - attribute path under ``self.scene`` of the shape's
            vertices (must end in "vertices");
            ``poses`` - a sequence of 3x3 rotation matrices, a path passed to
            ``np.load``, or an int count of random rotations to generate;
            ``filename`` - where the resolved poses are saved with ``np.save``.
            ``pose_id`` is written by ``run`` while iterating.
    """
    scene_file: str
    out_dir: str = "./"
    # default_factory gives each runner its own integrator instead of one
    # instance created at import time and shared by every runner.
    integrator: psdr_cpu.Integrator = dataclasses.field(
        default_factory=psdr_cpu.Direct)
    pose_setting: Dict = dataclasses.field(default_factory=dict)

    def update_scene(self, param_map):
        """Assign every value of ``param_map`` to the scene attribute its key names.

        Each key is spliced verbatim after ``self.scene.``, so it may be a
        nested attribute path or contain indexing; that is why the assignment
        is executed as a statement instead of using ``setattr``. The scene's
        derived data structures are rebuilt via ``configure()`` afterwards.
        """
        for key, value in param_map.items():
            # NOTE(security): keys come from the config and are executed as
            # code; never pass untrusted input here.
            exec("self.scene." + key + " = value",
                 globals(), {"self": self, "value": value})
        # update relevant scene data structures
        self.scene.configure()

    def render(self, context):
        """Render the scene for the current pose.

        Args:
            context: dict with keys ``"pose_setting"`` (see class docs; its
                ``pose_id`` selects the rotation), ``"param_map"`` (scene
                attribute path -> value), ``"integrator"`` and ``"options"``.

        Returns:
            The rendered image reshaped to ``(camera.height, camera.width, 3)``.
        """
        pose_setting = context["pose_setting"]
        # copy so the rotated vertices never overwrite the reference ones
        render_param_map = dict(context["param_map"])
        if pose_setting:
            key = pose_setting["id"]
            poses = pose_setting["poses"]
            # validate the pose_setting
            assert key in context["param_map"]
            assert key.endswith("vertices")  # make sure it's a shape
            rotation = poses[pose_setting["pose_id"]]
            assert rotation.shape == (3, 3)  # make sure it's a rotation
            # rows of V are treated as points: V' = V @ R^T
            V = context["param_map"][key]
            # NOTE: multiply in float64 to limit precision loss, store as float32
            with Timer("transform vertices"):
                V = (V.astype(np.float64)
                     @ rotation.T.astype(np.float64)).astype(np.float32)
            render_param_map[key] = V
        self.update_scene(render_param_map)
        image = context["integrator"].renderC(self.scene, context["options"])
        return image.reshape((self.scene.camera.height, self.scene.camera.width, 3))

    def run(self):
        """Load the scene, resolve the poses and write one EXR image per pose.

        The resolved poses are saved to ``pose_setting['filename']`` and the
        images to ``<out_dir>/target/pose_<i>.exr``. ``self.pose_setting`` is
        updated in place ('poses' resolved to an array, 'pose_id' set per
        iteration).

        Raises:
            AssertionError: if ``pose_setting`` is empty.
        """
        psdr_cpu.set_verbose(True)
        self.scene = Scene(self.scene_file)
        self.sceneAD = psdr_cpu.SceneAD(self.scene)
        self.width = self.scene.camera.width
        self.height = self.scene.camera.height
        self.fov = self.scene.camera.fov
        with gin.config_scope('tar'):
            self.options = RenderOptions(gin.REQUIRED)
        # fixed seeds so randomly generated poses are reproducible
        seed(1)
        np.random.seed(0)
        assert self.pose_setting, "pose_setting must be configured"

        # Same dict object that render() reads through the context below.
        pose_setting = self.pose_setting
        poses = pose_setting['poses']
        if isinstance(poses, str):
            poses = np.load(poses)
        elif isinstance(poses, int):
            poses = gen_random_rotations(poses)
        pose_setting['poses'] = poses

        pose_file = pose_setting['filename']
        pose_dir = os.path.dirname(pose_file)
        # dirname is '' for a bare filename: nothing to create in that case
        if pose_dir:
            mkdir(pose_dir)
        np.save(pose_file, poses)

        shape_key = pose_setting['id']
        # NOTE(security): evaluates a config-provided attribute path.
        param_map = {shape_key: np.array(eval("self.scene." + shape_key))}

        context = {
            "integrator": self.integrator,
            "options": self.options,
            "pose_setting": pose_setting,
            "param_map": param_map
        }

        target_dir = os.path.join(self.out_dir, "target")
        mkdir(self.out_dir)
        mkdir(target_dir)
        for i in range(len(poses)):
            # render() picks the rotation through the shared pose_setting dict
            pose_setting['pose_id'] = i
            image = self.render(context)
            imwrite(image, os.path.join(target_dir, "pose_{}.exr".format(i)))


if __name__ == "__main__":
    # Script entry: read the gin config named on the command line and run.
    seed(1)
    cli = argparse.ArgumentParser(
        description='Script for generating validation results')
    cli.add_argument('config_file', metavar='config_file', type=str,
                     nargs='?', default='./pig_pixel/base.conf',
                     help='config file')
    cli_args, _unknown = cli.parse_known_args()
    # split() yields exactly (dirname, basename) of the path.
    config_dir, config_name = os.path.split(cli_args.config_file)
    # Switch to the config's directory; relative paths used afterwards
    # (the config name, scene and output paths) resolve from there.
    if config_dir != '':
        os.chdir(config_dir)
    gin.add_config_file_search_path(os.getcwd())
    # Dependency injection: constructor arguments come from the gin config.
    gin.parse_config_file(config_name, skip_unknown=True)
    PreprocessRunner().run()
