from typing import List
import gin
from pypsdr.plot import write_obj
import psdr_cpu
import dataclasses
from dataclasses import field
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.loss import compute_render_loss, uniform_laplacian, L2_loss
from pypsdr.optimize import Camera
from pypsdr.optimizer import LGDescent,sparse_eye, LargeStepsOptimizer
from pypsdr.renderer import Render, RenderMi2
from pypsdr.common import gen_cameras, gen_cameras_mi
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"

from random import randrange, seed

import scipy
import scipy.sparse.linalg
import datetime

import drjit as dr
import mitsuba as mi
import matplotlib.pyplot as plt # We'll also want to plot some outputs
import os

import time


# Select Mitsuba's LLVM autodiff RGB variant at import time; it has to be
# active before the mi.load_file / mi.load_dict calls made by TrainRunner.
mi.set_variant("llvm_ad_rgb")

@gin.configurable
# generate the constructor
@dataclasses.dataclass
class TrainRunner:
    """Gin-configurable driver that optimizes one mesh's vertices by inverse rendering.

    ``__post_init__`` loads the initial psdr scene and the Mitsuba scene, reads
    the target images, builds one psdr camera and one Mitsuba sensor per camera
    position in ``camera_file``, and constructs the ``RenderMi2`` renderer.
    ``run`` then optimizes the vertices of shape ``shape_id`` with an L1 image
    loss and an ``LGDescent`` step preconditioned by ``I + lmbda * L``.
    """

    scene_init_file: str = "./kitty/ini.xml"      # psdr scene with the initial mesh
    scene_init_file_mi: str = "./kitty/ini.xml"   # scene file loaded by Mitsuba
    scene_target_file: str = "./kitty/tar.xml"    # NOTE(review): not read in this class
    shape_id: int = 0  # shape to be optimized
    niter: int = 500   # number of optimization iterations
    # Must be provided (e.g. through gin): spp, sppe, sppse1 and max_bounces
    # are read from it in __post_init__ / run.
    options: psdr_cpu.RenderOptions = None
    guiding_options: GuidingOptions = None  # NOTE(review): not read in this class
    out_dir: str = "./output"          # output directory
    # directory that contains target images
    target_dir: str = "./kitty/output/kitty/target"
    # Cameras rendered and saved every `print_size` iterations.
    test_indices: List = field(
        default_factory=lambda: [0, 14, 32, 57])

    lr: float = 1.0         # learning rate given to LGDescent
    lmbda: float = 500.0    # scale of the uniform Laplacian in I + lmbda * L
    batch_size: int = 10    # random training views rendered per iteration
    print_size: int = 10    # snapshot period (in iterations)

    camera_file: str = "./kitty/cam_pos.txt"  # camera positions, read with np.loadtxt
    # NOTE: these defaults are created once at class-definition time and are
    # therefore shared by every TrainRunner that does not override them.
    integrator: psdr_cpu.Integrator = psdr_cpu.Path2()
    mala_options: psdr_cpu.MALAOptions = psdr_cpu.MALAOptions()  # NOTE(review): not read here

    def __post_init__(self):
        """Load scenes, target images and cameras, and build the renderer."""
        mkdir(self.out_dir)
        mkdir(os.path.join(self.out_dir, "iter"))
        self.scene_init = Scene(self.scene_init_file)
        self.width = self.scene_init.camera.width
        self.height = self.scene_init.camera.height
        self.fov = self.scene_init.camera.fov
        self.train_images = np.array(read_images(self.target_dir))
        self.test_images = self.train_images[self.test_indices]
        # Reference strip: the target test views placed side by side.
        self.test_image = np.concatenate(self.test_images, axis=1)
        imwrite(self.test_image,
                os.path.join(self.out_dir, "test.png"))

        # Mitsuba scene-parameter keys of the optimized mesh.
        self.key = 'mesh.vertex_positions'
        self.key_f = 'mesh.faces'

        self.train_images = torch.from_numpy(self.train_images)
        self.test_images = torch.from_numpy(self.test_images)

        camera_pos = np.loadtxt(self.camera_file)
        # `target` is marked gin.REQUIRED, so it must come from the gin config.
        camera_info = Camera(target=gin.REQUIRED)

        self.scene_mi = mi.load_file(self.scene_init_file_mi)

        cam_list = gen_cameras_mi(positions=camera_pos,
                                  target=camera_info.target,
                                  up=camera_info.up,
                                  fov=camera_info.fov,
                                  resolution=[self.width, self.height], type=1)
        # Entry [0] is the psdr camera; entry [1] holds the three look_at
        # arguments used for the matching Mitsuba sensor below.
        self.cameras_psdr = [entry[0] for entry in cam_list]
        self.cam_info = [entry[1] for entry in cam_list]

        from mitsuba import ScalarTransform4f as T

        self.mi_sensors = []
        for cam_psdr, info in zip(self.cameras_psdr, self.cam_info):
            self.mi_sensors.append(mi.load_dict({
                'type': 'perspective',
                'fov': cam_psdr.fov,
                'to_world': T.look_at(info[0], info[1], info[2]),
                'film': {
                    'type': 'hdrfilm',
                    # Same resolution the psdr cameras were generated with
                    # (gen_cameras_mi above) instead of a fixed 256x256.
                    'width': self.width, 'height': self.height,
                    'filter': {'type': 'box'},
                    'sample_border': True,
                },
                'sampler': {
                    'type': 'independent',
                    'sample_count': self.options.spp,
                },
            }))

        self.integrator_mi = mi.load_dict({
            'type': 'prb_projective',
            'sppc': self.options.spp,
            'sppp': self.options.sppe,
            'sppi': self.options.sppse1,
            # presumably Mitsuba's max_depth counts one more path segment than
            # psdr's max_bounces — TODO confirm
            'max_depth': self.options.max_bounces + 1,
        })
        self.sceneAD = psdr_cpu.SceneAD(self.scene_init)

        self.render = RenderMi2(self.integrator_mi, self.integrator, self.mi_sensors,
                                self.cameras_psdr, self.options.max_bounces, self.shape_id)

    def _build_optimizer(self, V, E):
        """Return an LGDescent optimizer over ``V`` preconditioned by I + lmbda * L."""
        L = uniform_laplacian(V, E).detach() * self.lmbda
        IL_term = sparse_eye(L.shape[0]) + L
        # Mirror the torch sparse matrix in SciPy and factorize it once here.
        IL_coalesced = IL_term.coalesce()
        Lv = np.asarray(IL_coalesced.values())
        Li = np.asarray(IL_coalesced.indices())
        # factorized() operates on CSC; converting explicitly avoids SciPy's
        # implicit conversion (and its SparseEfficiencyWarning).
        IL_term_s = scipy.sparse.coo_matrix((Lv, Li), shape=L.shape).tocsc()
        IL_term_s_solver = scipy.sparse.linalg.factorized(IL_term_s)
        return LGDescent(
            params=[
                {'params': V,
                 'lr': self.lr},
            ],
            IL_term=IL_term,
            IL_solver=IL_term_s_solver)

    def _sync_psdr_vertices(self, V):
        """Push the optimized vertices into shape ``shape_id`` of the psdr scene."""
        # Use shape_id (not a hard-coded 0): V was read from this shape in run().
        self.sceneAD.val.shapes[self.shape_id].setVertices(V.detach().numpy())
        self.sceneAD.val.shapes[self.shape_id].configure()  # !
        self.sceneAD.val.configure()

    def _write_snapshot(self, it, V, F, params):
        """Write the mesh and the test-view renders for iteration ``it``."""
        write_obj(V.detach(), F.detach(),
                  os.path.join(self.out_dir, "iter", "iter_%d.obj" % it))
        print("write finished")
        test_images = []
        for test_id in self.test_indices:
            print("Testing camera:", test_id)
            test_img = self.render(self.scene_mi, self.sceneAD, params, self.key,
                                   it + 400, self.options.spp, V, test_id).detach().numpy()
            test_images.append(test_img)
        test_image = np.concatenate(test_images, axis=1)
        # Top row: target test views; bottom row: current renders.
        imwrite(np.concatenate([self.test_image, test_image], axis=0),
                os.path.join(self.out_dir, "iter", "iter_%d.exr" % it))

    def run(self):
        """Optimize the vertices of shape ``shape_id`` for ``niter`` iterations.

        Each iteration renders ``batch_size`` randomly chosen training views,
        back-propagates the per-view L1 loss into the vertex tensor, takes one
        optimizer step and copies the new vertices into the psdr scene. The loss
        history is rewritten to ``out_dir/loss.log`` after every iteration, and
        every ``print_size`` iterations the mesh and test renders are saved
        under ``out_dir/iter``.
        """
        F = torch.tensor(self.scene_init.shapes[self.shape_id].indices, dtype=torch.int32)
        E = torch.tensor(
            self.scene_init.shapes[self.shape_id].getEdges()).long()
        V = torch.tensor(
            self.scene_init.shapes[self.shape_id].vertices_world,
            dtype=torch.float32,
            requires_grad=True)

        # Overwrite the Mitsuba mesh faces with the psdr face indices
        # (presumably so both renderers share one vertex ordering).
        params = mi.traverse(self.scene_mi)
        params[self.key_f] = dr.llvm.ad.Int(F.flatten())
        params.update()

        optimizer = self._build_optimizer(V, E)
        ### WARN: add edge sort here ###

        seed(1)
        error = []

        for it in range(self.niter):
            start = datetime.datetime.now()
            optimizer.zero_grad()

            loss = 0.0
            for _ in range(self.batch_size):
                sensor_id = randrange(len(self.train_images))

                time0 = time.time()
                print(" ")
                # `it + 1400` (train) vs `it + 400` (test) — presumably a
                # per-iteration sampler seed; TODO confirm in RenderMi2.
                img = self.render(self.scene_mi, self.sceneAD, params, self.key,
                                  it + 1400, self.options.spp, V, sensor_id)
                time1 = time.time()
                print("forward time: ", time1 - time0)
                # L1 Loss
                img_loss = torch.sum(torch.abs(img - self.train_images[sensor_id]))
                loss += img_loss.item()

                time0 = time.time()
                print(" ")
                img_loss.backward()  # gradients accumulate over the batch
                time1 = time.time()
                print("backward time: ", time1 - time0)
                print("Rendering camera:", sensor_id,
                      "loss:", img_loss.item(), end='\r')

            # print stat
            print("V grad: ", V.grad)
            optimizer.step()
            self._sync_psdr_vertices(V)

            elapsed = datetime.datetime.now() - start
            print('Iteration:', it,
                  'Loss:', loss,
                  'Total time:', elapsed.total_seconds())
            error.append(loss)
            np.savetxt(os.path.join(self.out_dir, "loss.log"), error)

            # write files
            if it % self.print_size == 0:
                self._write_snapshot(it, V, F, params)


import argparse
if __name__ == "__main__":
    # Read an optional gin config path from the command line, bind it, train.
    cli = argparse.ArgumentParser(
        description='Script for generating validation results')
    cli.add_argument(
        'config_file',
        metavar='config_file',
        type=str,
        nargs='?',
        default='./aq_bunny_in_glass_configs_mi/aq_bunny_in_glass.conf',
        help='config file')
    # Extra command-line arguments are tolerated and ignored.
    known_args, _extra = cli.parse_known_args()
    # Dependency injection: gin fills TrainRunner's fields from the config
    # file; bindings for unregistered configurables are skipped.
    gin.parse_config_file(known_args.config_file, skip_unknown=True)
    TrainRunner().run()