from pyvredner import transform
from torch.optim import optimizer
import vredner
import pyvredner
import torch
from renderer import RenderFunction
import os
import copy
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class Trainer:
    """Recover a z-axis rotation of one scene shape by differentiable rendering.

    A target image is rendered with the shape's vertices rotated by
    ``target_param``. ``param`` (starting at ``init_param``) is then optimized
    with Adam so that the rendered image matches the target in mean squared
    error.
    """

    def __init__(self, *,
                 scene_path="./scene_side.xml",
                 shape_index=1,
                 init_param=0.,
                 target_param=0.6,
                 lr=0.1,
                 num_iter=100,
                 results_dir="./results") -> None:
        """Load the scene and set up the parameter, target and optimizer.

        All arguments are keyword-only and default to the values this
        script originally hard-coded, so ``Trainer()`` behaves as before.

        Args:
            scene_path: Mitsuba scene file passed to ``pyvredner.load_mitsuba``.
            shape_index: Index into ``py_scene.shapes`` of the shape to rotate.
            init_param: Starting value of the optimized rotation parameter.
            target_param: Rotation used to render the target image.
            lr: Adam learning rate.
            num_iter: Number of optimization iterations in ``run``.
            results_dir: Directory that receives images and logs.
        """
        self.py_scene, _integrator = pyvredner.load_mitsuba(scene_path)
        # A second scene copy is handed to the renderer together with the
        # primal one; presumably it receives derivative data — TODO confirm
        # against RenderFunction.
        self.d_py_scene = copy.deepcopy(self.py_scene)
        # Rest-pose vertices that transform_vertices rotates each step.
        self.vertices = self.py_scene.shapes[shape_index].vertices
        self.options = vredner.RenderOptions(13,      # random seed
                                             1024,    # spp
                                             6,       # max bounces
                                             100,     # sppse0
                                             100,     # sppe
                                             False)
        self.param = torch.tensor([float(init_param)], requires_grad=True)
        self.target_param = torch.tensor([float(target_param)])
        self.optimizer = torch.optim.Adam([self.param], lr=lr)
        self.num_iter = num_iter
        self.results_dir = Path(results_dir)
        self.render = RenderFunction.apply

    def transform_vertices(self, vertices, param: torch.Tensor):
        """Rotate around the z axis.

        Returns ``vertices`` rotated by ``param`` about the axis (0, 0, 1).
        Each row of ``vertices`` is treated as one point: it is multiplied
        on the right by the transposed rotation matrix. ``vertices`` is
        presumably an (N, 3) tensor — TODO confirm. The angle unit is
        whatever ``transform.gen_rotation_matrix3x3`` expects.
        """
        rotation = transform.gen_rotation_matrix3x3([0., 0., 1.], param)
        return vertices @ rotation.transpose(0, 1)

    def run(self):
        """Render the target and initial images, then optimize ``self.param``.

        Under the results directory this writes ``image_target.exr``,
        ``image_init.exr`` and one ``iterations/iter_<t>.exr`` per step.
        ``error.log`` (rows: iteration, parameter RMS error, image MSE) and
        ``param.log`` (parameter value after each step) are rewritten on
        every iteration, so partial progress is on disk if the run stops
        early.
        """
        results_dir = self.results_dir
        iterations_dir = results_dir / "iterations"
        iterations_dir.mkdir(parents=True, exist_ok=True)

        # The reference images are constants. Rendering them without
        # autograd avoids holding an unused graph (image_init depends on
        # the grad-requiring self.param) for the whole loop, and keeps the
        # loss from backpropagating into the target.
        with torch.no_grad():
            vertices = self.transform_vertices(self.vertices,
                                               self.target_param)
            image_target = self.render(
                vertices, self.py_scene, self.d_py_scene, self.options)
            pyvredner.imwrite(image_target,
                              str(results_dir / "image_target.exr"))
            vertices = self.transform_vertices(self.vertices, self.param)
            image_init = self.render(
                vertices, self.py_scene, self.d_py_scene, self.options)
            pyvredner.imwrite(image_init, str(results_dir / "image_init.exr"))

        error = []
        param_history = []
        for t in range(self.num_iter):
            self.optimizer.zero_grad()
            # Use a different random seed each iteration (1..num_iter).
            self.options.seed = t + 1
            vertices = self.transform_vertices(self.vertices, self.param)
            image = self.render(vertices, self.py_scene,
                                self.d_py_scene, self.options)
            image_loss = (image - image_target).pow(2).mean()
            image_loss.backward()
            with torch.no_grad():
                # Distance to the ground truth, measured before this step's
                # update, i.e. for the same parameter that produced `image`.
                param_loss = (self.param - self.target_param).pow(2).mean().sqrt()
            self.optimizer.step()
            pyvredner.imwrite(image, str(iterations_dir / f"iter_{t}.exr"))
            error.append([t,
                          param_loss.detach().numpy(),
                          image_loss.detach().numpy()])
            # .numpy() shares memory with self.param, which the optimizer
            # updates in place; np.array snapshots the current value.
            param_history.append(np.array(self.param.detach().numpy()))
            print(self.param)
            np.savetxt(str(results_dir / "error.log"), error)
            np.savetxt(str(results_dir / "param.log"), param_history)


if __name__ == "__main__":
    # Script entry point: build a trainer with default settings and run it.
    Trainer().run()
