import argparse
from typing import Dict, List
import gin
from pypsdr.plot import get_mesh_error, write_obj
from pypsdr.renderer import Model
import psdr_cpu
import dataclasses
from dataclasses import field
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.loss import compute_render_loss, uniform_laplacian
from pypsdr.optimize import Camera
from pypsdr.optimizer import LGDescent, sparse_eye
from pypsdr.renderer import Model, Render
from pypsdr.common import gen_cameras
from pypsdr.largesteps.parameterize import from_differential, to_differential
from pypsdr.largesteps.geometry import compute_matrix
from pypsdr.largesteps.optimize import AdamUniform
from random import randrange, seed
from torch.utils.tensorboard import SummaryWriter
import scipy
import scipy.sparse.linalg
import datetime
import logging

# Configure the root logger once at import time: INFO level and a compact
# "LEVEL:message" record format.
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)


@gin.configurable
# generate the constructor
@dataclasses.dataclass
class TrainRunner:
    """Gin-configurable driver for a mini-batch render-and-optimize loop.

    Construction (``__post_init__``) creates the output directories, loads
    the initial and target scenes, reads the target images from
    ``target_dir``, writes a reference strip of the test views to
    ``<out_dir>/test.png`` and builds the ``Model`` that is optimized.

    ``run()`` then performs ``niter`` iterations.  Each iteration renders
    ``batch_size`` views, back-propagates the image loss of every view,
    takes one ``model.step()``, and logs the image loss and the mesh error
    against the first shape of the target scene.

    When ``pose_setting`` is non-empty the runner iterates over poses
    (always rendering with sensor 0) instead of over cameras.  In that mode
    ``pose_setting`` must provide ``'filename'`` (loaded with ``np.load``)
    and ``'test_indices'``; ``'train_indices'`` is optional and defaults to
    every loaded pose.
    """

    scene_init_file: str = "./kitty/ini.xml"
    scene_target_file: str = "./kitty/tar.xml"
    niter: int = 500                   # number of optimization iterations
    options: psdr_cpu.RenderOptions = None
    # When set, used as the third argument of the backward RenderOptions.
    spp_dr: int = None
    out_dir: str = "./output"          # output directory
    # directory that contains target images
    target_dir: str = "./kitty/output/kitty/target"
    # Image indices shown in the test strip (camera mode only).
    test_indices: List = field(
        default_factory=lambda: [0, 14, 32, 57])
    img_threshold: float = 2.          # upper clamp applied before the loss
    batch_size: int = 10               # views rendered per iteration
    print_size: int = 10               # render the test strip every N iterations
    train_indices: np.array = None     # None / empty -> use every target image
    # NOTE(review): misspelling of `parameters`; never read in this class but
    # kept because it is part of the constructor interface.
    paramters: List = field(default_factory=lambda: [])
    # Per-iteration dumps; each entry is a dict with a 'type' of
    # "obj", "shape" or "texture" (see _write_outputs).
    outputs: List = field(default_factory=lambda: [])
    camera_file: str = ""              # optional text file of camera positions
    # NOTE(review): this default is evaluated once at class definition, so
    # every runner that relies on it shares the same integrator object.
    integrator: psdr_cpu.Integrator = psdr_cpu.Direct()
    parameters: Dict = field(default_factory=lambda: {})
    pose_setting: Dict = field(default_factory=lambda: {})
    scene_config : Dict = None

    def __post_init__(self):
        """Load scenes and target images, then build cameras and the model."""
        mkdir(self.out_dir)
        mkdir(os.path.join(self.out_dir, "iter"))
        self.scene_init = Scene(self.scene_init_file, self.scene_config)
        self.scene_tar = Scene(self.scene_target_file)
        init_camera = self.scene_init.camera
        self.width = init_camera.width
        self.height = init_camera.height
        self.fov = init_camera.fov

        # Reference strip: the selected test views concatenated along axis 1.
        self.train_images = np.array(read_images(self.target_dir))
        test_indices = (self.pose_setting["test_indices"]
                        if self.pose_setting else self.test_indices)
        self.test_images = self.train_images[test_indices]
        self.test_image = np.concatenate(self.test_images, axis=1)
        imwrite(self.test_image,
                os.path.join(self.out_dir, "test.png"))

        self.train_images = torch.from_numpy(self.train_images)
        self.test_images = torch.from_numpy(self.test_images)

        # `if self.train_indices:` is ambiguous for ndarrays (it raises for
        # more than one element and is False for np.array([0])), so test for
        # "provided and non-empty" explicitly.
        if self.train_indices is not None and len(self.train_indices) > 0:
            self.train_indices = np.array(self.train_indices)
            assert (self.train_indices < len(self.train_images)).all(), \
                "train_indices contains an index past the last target image"
        else:
            self.train_indices = np.arange(len(self.train_images))

        self.cameras = [self.scene_init.camera]

        if self.camera_file:
            camera_pos = np.loadtxt(self.camera_file)
            camera_info = Camera(target=gin.REQUIRED)
            self.cameras = gen_cameras(positions=camera_pos,
                                       target=camera_info.target,
                                       up=camera_info.up,
                                       fov=camera_info.fov,
                                       resolution=[self.width, self.height],
                                       type=camera_info.type)

        if self.pose_setting:
            self.pose_setting['poses'] = np.load(self.pose_setting['filename'])
            if 'train_indices' not in self.pose_setting:
                self.pose_setting['train_indices'] = np.arange(
                    len(self.pose_setting['poses']))
            assert('test_indices' in self.pose_setting)

        self.sceneAD = psdr_cpu.SceneAD(self.scene_init)
        context = {
            "integrator": self.integrator,
            "options": self.options,
            "sceneAD": self.sceneAD,
            "cameras": self.cameras
        }
        self.model = Model(self.sceneAD, self.parameters, context=context)
        # psdr_cpu.set_verbose(True)
        self.writer = SummaryWriter(log_dir=self.out_dir)

    def _shuffled_tape(self, indices):
        """Return a flat tape of independently shuffled copies of `indices`.

        Enough copies are concatenated to cover ``batch_size * niter``
        draws.  Uses the global NumPy RNG, so the caller controls seeding.
        """
        indices = np.asarray(indices)
        nepoch = self.batch_size * self.niter // len(indices) + 1
        tape = np.array([], dtype=np.int32)
        for _ in range(nepoch):
            shuffled = indices.copy()
            np.random.shuffle(shuffled)
            tape = np.concatenate((tape, shuffled))
        return tape

    def _scene_attr(self, expr):
        """Evaluate `expr` as an attribute path below ``self.sceneAD.val``.

        SECURITY: `expr` comes from the `outputs` configuration and is passed
        to ``eval``; only use configuration files you trust.
        """
        return eval("self.sceneAD.val.%s" % expr)

    def _write_outputs(self, it):
        """Write every configured entry of `outputs` for iteration `it`."""
        iter_dir = os.path.join(self.out_dir, "iter")
        for o in self.outputs:
            if o['type'] == "obj":
                V = self._scene_attr(o['V'])
                F = self._scene_attr(o['F'])
                write_obj(V, F, os.path.join(
                    iter_dir, o['prefix'] + "_%04d.obj" % it))
            elif o['type'] == "shape":
                # The previous exec()-based call could never run: its format
                # string expected two values but got one, and the path was
                # passed as exec's globals argument.  Resolve the shape and
                # call save() directly instead.
                # NOTE(review): assumes the shape exposes save(path), as the
                # original call intended -- confirm against psdr_cpu.
                shape = self._scene_attr(o['shape'])
                shape.save(os.path.join(
                    iter_dir, o['prefix'] + "_%04d.obj" % it))
            elif o['type'] == "texture":
                bitmap = self._scene_attr(o['bitmap'])
                bitmap = np.array(bitmap.m_data).reshape(
                    bitmap.m_res[1], bitmap.m_res[0], 3)
                imwrite(bitmap, os.path.join(
                    iter_dir, "iter_%04d.exr" % it))

    def _render_test_strip(self, it):
        """Render all test views and return them concatenated along axis 1.

        In pose mode this updates ``self.pose_setting['pose_id']`` as a side
        effect, exactly as the training loop expects.
        """
        test_images = []
        if self.pose_setting:
            for test_id in self.pose_setting['test_indices']:
                print("Testing pose:", test_id, end='\r')
                self.pose_setting.update({'pose_id': test_id})
                test_img = self.model.render(
                    context={"sensor_id": 0,
                             "pose_setting": self.pose_setting,
                             "options": RenderOptions(it, gin.REQUIRED, 128, 0, 0)}).detach()
                test_images.append(test_img)
        else:
            for test_id in self.test_indices:
                print("Testing camera:", test_id, end='\r')
                test_img = self.model.render(
                    context={"sensor_id": test_id,
                             "options": RenderOptions(it, gin.REQUIRED, 128, 0, 0)}).detach()
                test_images.append(test_img)
        return np.concatenate(test_images, axis=1)

    def run(self):
        """Run the optimization loop for `niter` iterations.

        Side effects per iteration: overwrites ``init.png`` / ``tar.png`` in
        `out_dir` with the latest render and target, rewrites ``loss.log``
        and ``error.log``, adds the "image loss" and "mesh error" scalars to
        the summary writer, writes the configured `outputs`, and every
        `print_size` iterations writes a comparison image to
        ``<out_dir>/iter/iter_XXXX.png``.
        """
        # setup optimizers
        print('Starting optimization')
        seed(1)
        np.random.seed(0)
        # generate mini-batches
        idx_shuffled = self._shuffled_tape(self.train_indices)

        # update the train index tape
        if self.pose_setting:
            # Shuffle the *pose* indices.  The tape used to be built from
            # self.train_indices while its length was derived from the pose
            # indices, which ignored the configured poses.
            idx_shuffled = self._shuffled_tape(
                self.pose_setting['train_indices'])
            self.pose_setting['idx_shuffled'] = idx_shuffled

        error = []
        mesh_errors = []
        n_rendered = 0  # total views rendered so far; indexes the pose tape

        # The target mesh never changes, so read it once.
        target_shape = self.scene_tar.shapes[0]
        V_t = np.array(target_shape.vertices)
        F_t = np.array(target_shape.indices)

        for it in range(self.niter):
            now = datetime.datetime.now()
            self.model.zero_grad()
            loss = 0.0
            indices = idx_shuffled[it *
                                   self.batch_size: (it + 1) * self.batch_size]
            for sensor_id in indices:
                # Work on a shallow copy so 'pose_id' is set per render
                # without touching self.pose_setting.
                pose_setting = dict(self.pose_setting)
                if pose_setting:
                    sensor_id = 0
                    pose_setting['pose_id'] = pose_setting['idx_shuffled'][n_rendered]
                    print("rendering pose: ", pose_setting['pose_id'], "\n")
                # render
                img = self.model.render(context={"sensor_id": sensor_id,
                                                 "pose_setting": pose_setting,
                                                 "options": RenderOptions(it,
                                                                          gin.REQUIRED,
                                                                          gin.REQUIRED,
                                                                          gin.REQUIRED),
                                                 "bwd_options": RenderOptions(it,
                                                                              gin.REQUIRED,
                                                                              self.spp_dr if self.spp_dr else gin.REQUIRED,
                                                                              gin.REQUIRED)})
                if pose_setting:
                    tar = self.train_images[pose_setting['pose_id']]
                else:
                    tar = self.train_images[sensor_id]
                imwrite(img.detach().numpy(), os.path.join(
                    self.out_dir, "init.png"))
                imwrite(tar.detach().numpy(), os.path.join(
                    self.out_dir, "tar.png"))
                # clamping
                img = img.clamp(0., self.img_threshold)
                tar = tar.clamp(0., self.img_threshold)
                img_loss = compute_render_loss(
                    img, tar, 1.0)
                print("image_loss: ", img_loss.item())
                # Accumulate over the batch: gradients are accumulated by the
                # backward() calls, so the logged loss should cover the same
                # views rather than only the last one.
                loss += img_loss.item()
                img_loss.backward()
                print("Rendering camera:", sensor_id,
                      "loss:", img_loss.item(), end='\r')
                n_rendered += 1

            # step
            self.model.step()

            # TODO update scene with new parameters

            # print stat
            elapsed = datetime.datetime.now() - now
            print("rendered camera:", indices)
            print('Iteration:', it,
                  'Loss:', loss,
                  'Total time:', elapsed.total_seconds())
            error.append(loss)

            current_shape = self.sceneAD.val.shapes[0]
            V_i = np.array(current_shape.vertices)
            F_i = np.array(current_shape.indices)
            mesh_error = get_mesh_error(V_i, F_i, V_t, F_t)
            mesh_errors.append(mesh_error)
            self.writer.add_scalar("image loss", loss, it)
            self.writer.add_scalar("mesh error", mesh_error, it)
            # Logs are rewritten in full each iteration so they stay usable
            # if the run is interrupted.
            np.savetxt(os.path.join(self.out_dir, "loss.log"), error)
            np.savetxt(os.path.join(self.out_dir, "error.log"), mesh_errors)
            # TODO save parameters

            # TODO write files
            self._write_outputs(it)

            if it % self.print_size == 0:
                test_image = self._render_test_strip(it)
                # Reference strip on top, current renders below.
                imwrite(np.concatenate([self.test_image, test_image], axis=0),
                        os.path.join(self.out_dir, "iter", "iter_%04d.png" % it))

        # Make sure the last scalars reach disk before returning.
        self.writer.flush()


if __name__ == "__main__":
    # Command-line entry point: read an optional gin config path, switch to
    # the config's directory so relative paths inside it resolve, then build
    # and run the trainer with gin-injected arguments.
    cli = argparse.ArgumentParser(
        description='Script for generating validation results')
    cli.add_argument('config_file', metavar='config_file', type=str,
                     nargs='?', default='./terrain/with_anti.conf',
                     help='config file')
    # Unrecognised arguments are tolerated and ignored.
    parsed, _ignored = cli.parse_known_args()

    config_dir, config_name = os.path.split(parsed.config_file)
    if config_dir:
        os.chdir(config_dir)

    # Dependency injection: Arguments are injected into the function from the gin config file.
    gin.add_config_file_search_path(os.getcwd())
    gin.parse_config_file(config_name, skip_unknown=True)

    runner = TrainRunner()
    runner.run()
