import scipy.ndimage.filters
import pyredner
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
from typing import List
import gin
from pypsdr.plot import write_obj
import psdr_cpu
import dataclasses
from dataclasses import field
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.loss import compute_render_loss, uniform_laplacian
from pypsdr.optimize import Camera
from pypsdr.optimizer import LGDescent,sparse_eye
from scipy.spatial.transform import Rotation as R
from pypsdr.renderer import Render2

from random import randrange, seed

import scipy
import scipy.sparse.linalg
import datetime

# Keep pyredner on the CPU: the script below calls `.numpy()` on rendered
# tensors and mixes them with CPU-side numpy/scipy data.
pyredner.set_use_gpu(False)

@gin.configurable
@dataclasses.dataclass  # generates __init__ from the annotated fields below
class TrainRunner:
    """Optimize the vertices of one mesh against a set of target images.

    The initial scene is loaded with both pyredner (used for the
    differentiable training renders) and psdr (used for the periodic test
    renders). Each training view rotates the mesh by a per-view rotation
    vector read from ``obj_pos_file``. The image loss is back-propagated into
    the vertices, and the vertices are updated with ``LGDescent``, which
    receives ``I + lmbda * L`` (``L`` = uniform Laplacian of the mesh) and a
    sparse factorization of that matrix.

    All fields are injectable through gin (``gin.configurable``).
    """

    scene_init_file: str = "./kitty/ini.xml"    # initial scene; loaded by pyredner.load_mitsuba and psdr Scene
    scene_target_file: str = "./kitty/tar.xml"  # NOTE(review): not referenced by this class
    shape_id: int = 0  # index of the shape to be optimized (same index in both scenes)
    niter: int = 500   # number of optimization iterations
    options: psdr_cpu.RenderOptions = None      # passed to the psdr Render2 test renderer
    guiding_options: GuidingOptions = None      # passed to the psdr Render2 test renderer
    out_dir: str = "./output"          # output directory
    # directory that contains target images
    target_dir: str = "./kitty/output/kitty/target"
    # views rendered with psdr at every checkpoint; also selects the target test images
    test_indices: List = field(
        default_factory=lambda: [0, 14, 32, 57])

    lr: float = 1.0         # step size given to LGDescent
    lmbda: float = 500.0    # weight of the uniform Laplacian in (I + lmbda * L)
    batch_size: int = 10    # views per step; run() currently supports only 1
    print_size: int = 10    # checkpoint every print_size iterations; 0 disables checkpoints

    # NOTE(review): not used by this class; run() builds its own pyredner integrator.
    integrator: psdr_cpu.Integrator = psdr_cpu.Direct()

    # Text file loaded with np.loadtxt; row i gives the object rotation for view i
    # (entries 0-2 scaled by entry 3 form a rotation vector).
    obj_pos_file: str = "obj_pos.txt"

    def __post_init__(self):
        """Create output dirs, load both scenes and the targets, and build the psdr renderer."""
        mkdir(self.out_dir)
        mkdir(os.path.join(self.out_dir, "iter"))
        self.scene_redner = pyredner.load_mitsuba(self.scene_init_file)
        self.scene_init_psdr = Scene(self.scene_init_file)

        # Copy pyredner's mesh of the optimized shape into the psdr scene so both
        # renderers start from identical topology and vertex positions.
        redner_shape = self.scene_redner.shapes[self.shape_id]
        self.scene_init_psdr.shapes[self.shape_id].indices = \
            redner_shape.indices.detach().numpy()
        self.scene_init_psdr.shapes[self.shape_id].setVertices(
            redner_shape.vertices.detach().numpy())
        self.scene_init_psdr.configure()
        self.scene_init_psdr.shapes[self.shape_id].configure()

        self.width = self.scene_init_psdr.camera.width
        self.height = self.scene_init_psdr.camera.height
        self.fov = self.scene_init_psdr.camera.fov

        self.train_images = np.array(read_images(self.target_dir))
        self.test_images = self.train_images[self.test_indices]
        # Targets of the test views side by side; later stacked above the current renders.
        self.test_image = np.concatenate(self.test_images, axis=1)
        imwrite(self.test_image,
                os.path.join(self.out_dir, "test.png"))

        self.train_images = torch.from_numpy(self.train_images)
        self.test_images = torch.from_numpy(self.test_images)
        self.sceneAD = psdr_cpu.SceneAD(self.scene_init_psdr)

        self.obj_pos = np.loadtxt(self.obj_pos_file)
        self.cameras = [self.scene_init_psdr.camera]
        self.render_psdr = Render2(self.sceneAD, self.cameras,
                                   psdr_cpu.Path2(), self.options,
                                   self.guiding_options, 0, self.obj_pos)

        psdr_cpu.set_verbose(True)

    def run(self):
        """Run ``niter`` optimization iterations on the vertices of ``shape_id``.

        Each iteration renders one training view with pyredner, back-propagates
        the image loss into the vertices, takes an ``LGDescent`` step, writes the
        new vertices into both psdr scenes and rewrites ``out_dir/loss.log``.
        Every ``print_size`` iterations a mesh and a psdr test-view image are
        written under ``out_dir/iter``.

        Raises:
            ValueError: if ``batch_size`` is not 1.
        """
        # Validate before any expensive setup; an assert would vanish under `python -O`.
        if self.batch_size != 1:
            raise ValueError(
                "TrainRunner.run supports batch_size == 1 only, got %r"
                % (self.batch_size,))

        render = pyredner.RenderFunction.apply
        options = RenderOptions(
            0, gin.REQUIRED, gin.REQUIRED, gin.REQUIRED)
        integrator = pyredner.integrators.WarpFieldIntegrator(
            num_samples=options.spp,
            max_bounces=options.max_bounces,
            kernel_parameters=pyredner.integrators.KernelParameters(
                vMFConcentration=1e4,
                auxPrimaryGaussianStddev=0.01,
                numAuxiliaryRays=8
            )
        )
        shape_psdr = self.scene_init_psdr.shapes[self.shape_id]
        F = torch.tensor(shape_psdr.indices).long()
        E = torch.tensor(shape_psdr.getEdges()).long()
        V = self.scene_redner.shapes[self.shape_id].vertices \
            .clone().detach().requires_grad_(True)
        print('Starting optimization')

        L = uniform_laplacian(V, E).detach() * self.lmbda
        I = sparse_eye(L.shape[0])
        IL_term = I + L
        # Factorize (I + L) once with SciPy. factorized/splu work on CSC and
        # would otherwise convert the COO input themselves (with a warning).
        IL_coalesced = IL_term.coalesce()
        IL_term_s = scipy.sparse.coo_matrix(
            (np.asarray(IL_coalesced.values()),
             np.asarray(IL_coalesced.indices())),
            shape=L.shape).tocsc()
        IL_term_s_solver = scipy.sparse.linalg.factorized(IL_term_s)

        optimizer = LGDescent(
            params=[
                {'params': V,
                 'lr': self.lr},
            ],
            IL_term=IL_term,
            IL_solver=IL_term_s_solver)

        seed(1)
        losses = []
        loss_log = os.path.join(self.out_dir, "loss.log")

        for it in range(self.niter):
            now = datetime.datetime.now()

            # NOTE(review): this rebinding is not read below (the integrator was
            # built from the pre-loop `options`); kept in case constructing
            # RenderOptions has side effects - confirm and remove.
            options = RenderOptions(
                it, gin.REQUIRED, gin.REQUIRED, gin.REQUIRED)
            optimizer.zero_grad()

            for j in range(self.batch_size):
                pos_id = (it * self.batch_size + j) % len(self.train_images)

                # Rotate the current mesh into this view's object pose.
                pos = self.obj_pos[pos_id]
                r = R.from_rotvec(pos[3] * pos[0:3])
                rot_position = V @ torch.tensor(r.as_matrix().transpose(),
                                                dtype=torch.float32)
                self.scene_redner.shapes[self.shape_id].vertices = rot_position
                args = pyredner.RenderFunction.serialize_scene_class(
                    scene=self.scene_redner,
                    integrator=integrator)
                img = render(it * it + 1, *args)
                imwrite(img.detach().numpy(),
                        os.path.join(self.out_dir, "iter", "test_redner.png"))
                img_loss = compute_render_loss(
                    img, self.train_images[pos_id], 1.0)
                loss = img_loss.item()
                img_loss.backward()
                print("Rendering camera:", pos_id, "loss:", loss)
                del img_loss

            # NOTE(review): tiny nudge on the first vertex's gradient; its purpose
            # is not evident from this file - confirm before changing.
            V.grad[0] += 0.0000000001

            optimizer.step()

            # Push the (unrotated) optimized vertices into the psdr scenes.
            V_np = V.detach().numpy()
            self.sceneAD.val.shapes[self.shape_id].setVertices(V_np)
            self.scene_init_psdr.shapes[self.shape_id].setVertices(V_np)

            print("grad: ", V.grad.abs().sum().item())
            elapsed = datetime.datetime.now() - now
            print('Iteration:', it,
                  'Loss:', loss,
                  'Total time:', elapsed.total_seconds())
            losses.append(loss)
            np.savetxt(loss_log, losses)

            # print_size == 0 would otherwise raise ZeroDivisionError.
            if self.print_size != 0 and it % self.print_size == 0:
                self._write_checkpoint(it, V, F)

    def _write_checkpoint(self, it, V, F):
        """Write the current mesh and a target-vs-render image of the test views to ``out_dir/iter``."""
        write_obj(V.detach(), F.detach(),
                  os.path.join(self.out_dir, "iter", "iter_%d.obj" % it))
        test_images = []
        for test_id in self.test_indices:
            print("Testing camera:", test_id, end='\r')
            test_img = self.render_psdr(
                V.detach(), test_id,
                {"options": RenderOptions(it, gin.REQUIRED, 64, 0, 0)})
            test_images.append(test_img)
        test_image = np.concatenate(test_images, axis=1)
        # Top row: target test views; bottom row: current psdr renders.
        imwrite(np.concatenate([self.test_image, test_image], axis=0),
                os.path.join(self.out_dir, "iter", "iter_%d.png" % it))




import argparse

if __name__ == "__main__":
    # CLI: one optional positional argument naming the gin config file.
    # Flags the parser does not recognise are accepted and ignored.
    cli = argparse.ArgumentParser(
        description='Script for generating validation results')
    cli.add_argument('config_file',
                     metavar='config_file',
                     type=str,
                     nargs='?',
                     default='./gen1.conf',
                     help='config file')
    parsed, _ignored = cli.parse_known_args()

    # Gin injects the TrainRunner fields from the config; bindings for
    # configurables it does not know about are skipped.
    gin.parse_config_file(parsed.config_file, skip_unknown=True)

    TrainRunner().run()
    