from torch.torch_version import Version
from pyvredner import transform
from torch.optim import optimizer
import vredner
import pyvredner
import torch
from pypsdr.renderer import Render
from pypsdr.utils.io import *
from pypsdr.validate import *
import copy
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os  # explicit: previously only available via a wildcard import

# NOTE(review): presumably works around the Intel/LLVM OpenMP "duplicate
# runtime" abort (two copies of libiomp loaded, e.g. by torch and another
# extension). The runtime reads this when it initializes, so it only helps if
# set before that happens - confirm it is early enough here.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class Trainer:
    """Optimize the vertices of one mesh so its rendering matches a target.

    The initial and target scenes are loaded from XML files. The vertices of
    shape ``shape_id`` in the initial scene are the trainable parameters.
    Each iteration renders them, compares the result with a render of the
    target scene using an RMSE loss, and takes an Adam step. Reference
    images, per-iteration images and a loss log are written under
    ``results_dir``.
    """

    def __init__(self, init_xml: str = "./scene.xml",
                 target_xml: str = "./tar.xml", *,
                 shape_id: int = 2, num_iter: int = 100, lr: float = 0.05,
                 train_spp: int = 128,
                 results_dir: str = "./results") -> None:
        """Load both scenes and set up the renderer and optimizer.

        Args:
            init_xml: Scene file whose mesh is the optimization start point.
            target_xml: Scene file rendered once as the reference image.
            shape_id: Index into ``scene.shapes`` of the mesh to optimize.
            num_iter: Number of optimization steps performed by ``run``.
            lr: Adam learning rate.
            train_spp: Samples per pixel used for the training renders
                (the reference renders use the 1024 spp set below).
            results_dir: Directory receiving images and the loss log.
        """
        self.init_scene = Scene(init_xml)
        self.tar_scene = Scene(target_xml)
        self.shape_id = shape_id
        self.num_iter = num_iter
        self.train_spp = train_spp
        self.results_dir = Path(results_dir)
        # Trainable copy of the mesh vertices being optimized.
        self.params = torch.tensor(
            np.array(self.init_scene.shapes[shape_id].vertices),
            dtype=torch.float, requires_grad=True)
        self.options = vredner.RenderOptions(13,    # random seed
                                             1024,  # spp
                                             2,     # max bounces
                                             100,   # sppse0
                                             100,   # sppe
                                             False)

        self.sceneAD = vredner.SceneAD(self.init_scene)
        self.optimizer = torch.optim.Adam([self.params], lr=lr)
        self.integrator = vredner.Direct()
        self.cameras = [self.init_scene.camera]
        self.render = Render(self.sceneAD, self.cameras, self.integrator,
                             self.options, shape_id=shape_id)
        vredner.set_verbose(True)

    def run(self) -> None:
        """Render the reference images, then run the optimization loop.

        Writes ``image_target.exr``, ``image_init.exr``, one
        ``iterations/iter_<t>.exr`` per step and ``error.log`` (rows of
        ``[iteration, RMSE]``) into ``results_dir``.
        """
        sid = self.shape_id
        out = self.results_dir
        (out / "iterations").mkdir(parents=True, exist_ok=True)

        # Build the target vertices once, with the same dtype as the
        # trainable params, so both renders receive consistent inputs.
        target_vertices = torch.tensor(
            np.array(self.tar_scene.shapes[sid].vertices), dtype=torch.float)
        self.render.setState({'sceneAD': vredner.SceneAD(self.tar_scene)})
        print("target:", target_vertices)
        # The target is a constant in the loss; keep it out of any graph.
        image_target = self.render(V=target_vertices, sensor_id=0).detach()
        imwrite(image_target.numpy(), str(out / "image_target.exr"))

        self.render.setState({'sceneAD': vredner.SceneAD(self.init_scene)})
        image_init = self.render(V=self.params, sensor_id=0)
        imwrite(image_init.detach().numpy(), str(out / "image_init.exr"))

        # Training renders use fewer samples than the reference renders.
        # Loop-invariant, so it is set once here.
        self.options.spp = self.train_spp
        error = []
        for t in range(self.num_iter):
            self.optimizer.zero_grad()
            # A different random seed for every iteration.
            self.options.seed = t + 1
            image = self.render(V=self.params, sensor_id=0,
                                params={'options': self.options})
            image_loss = (image - image_target).pow(2).mean().sqrt()  # RMSE
            image_loss.backward()
            self.optimizer.step()
            # Push the updated vertices into the scene used by the renderer.
            # NOTE(review): detach().numpy() shares memory with params; fine
            # if setVertices copies the data - confirm.
            self.sceneAD.val.shapes[sid].setVertices(
                self.params.detach().numpy())
            self.render.setState({'sceneAD': self.sceneAD})
            imwrite(image.detach().numpy(),
                    str(out / "iterations" / f"iter_{t}.exr"))
            error.append([t, image_loss.item()])
            print(self.params)
            # Rewritten every iteration, so the log stays current even if the
            # run is interrupted.
            np.savetxt(str(out / "error.log"), error)


if __name__ == "__main__":
    # Script entry point: build the trainer with its defaults and optimize.
    Trainer().run()
