from pyvredner import transform
from torch.optim import optimizer
import vredner
import pyvredner
import torch
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.loss import compute_render_loss, uniform_laplacian
from pypsdr.optimizer import LGDescent, sparse_eye
from pypsdr.renderer import Render
from pypsdr.common import gen_cameras
import os
import copy
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# NOTE(review): presumably lets the process keep running when more than one copy
# of the Intel OpenMP runtime ends up loaded (e.g. one bundled with torch and one
# with the renderer extensions). Confirm this workaround is still required.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class Trainer:
    """Recover a translation parameter by differentiable rendering.

    A target image is rendered with the first shape shifted by
    ``target_param``. Then ``param`` is optimized with Adam so that the
    rendering of the shifted shape matches that target. Per-iteration images
    are written to ``./results/iterations`` and loss/parameter histories to
    ``./results``.
    """

    def __init__(self) -> None:
        # Positional render options, annotated as in the original call. The
        # sixth value is later overwritten as ``options.sppse1`` in ``run``, so
        # it is presumably that field (TODO confirm against vredner).
        self.options = vredner.RenderOptions(13,     # random seed
                                             2048,   # spp
                                             6,      # max bounces
                                             0,      # sppse0
                                             0,      # sppe
                                             0,      # presumably sppse1
                                             False)

        # NOTE(review): GuidingOptions is not imported by name in this file; it
        # is presumably provided by one of the star imports. Confirm.
        self.guiding_options = GuidingOptions()
        # Scalar translation being optimized, and the value used to render the
        # target image.
        self.param = torch.tensor([0.], requires_grad=True)
        self.target_param = torch.tensor([0.6])
        self.optimizer = torch.optim.Adam([self.param], lr=0.02)
        self.num_iter = 100
        self.scene = vredner.Scene("./init.xml")
        self.sceneAD = vredner.SceneAD(self.scene)
        self.integrator = vredner.Path2()
        self.render = Render(
            self.sceneAD, [self.scene.camera], self.integrator, self.options, self.guiding_options, 0)
        # Rest-pose vertices of the first shape, copied into a tensor. Every
        # render is given a translated copy of these.
        self.vertices = torch.tensor(np.array(self.scene.shapes[0].vertices))
        vredner.set_verbose(True)

    def transform_vertices(self, vertices, param: torch.Tensor):
        """Translate ``vertices`` along the +x axis by ``param``.

        Args:
            vertices: vertex positions. Presumably an (N, 3) tensor; the
                offset ``[1, 0, 0] * param`` is broadcast against it.
            param: translation amount. Gradients flow back to it.

        Returns:
            A new tensor ``vertices + [1, 0, 0] * param``. The input tensor is
            not modified.
        """
        return vertices + torch.tensor([1., 0., 0.]) * param

    def run(self):
        """Render the target and initial images, then optimize ``self.param``.

        Writes ``image_target.exr`` and ``image_init.exr`` to ``./results`` and
        one EXR per iteration to ``./results/iterations``. After every
        iteration it rewrites ``error.log`` (iteration, parameter RMSE, image
        MSE) and ``param.log``.
        """
        Path("./results/iterations").mkdir(parents=True, exist_ok=True)

        vertices = self.transform_vertices(self.vertices, self.target_param)
        image_target = self.render(vertices, sensor_id=0)
        imwrite(image_target.detach().numpy(), "./results/image_target.exr")

        # The initial image is only saved, never differentiated. Render it from
        # a detached parameter (the same no-grad input path as the target
        # render above), so that ``image_init`` cannot keep an autograd graph
        # alive for the rest of the loop.
        vertices = self.transform_vertices(self.vertices, self.param.detach())
        image_init = self.render(vertices, sensor_id=0)
        imwrite(image_init.detach().numpy(), "./results/image_init.exr")

        error_history = []
        param_history = []
        for t in range(self.num_iter):
            self.optimizer.zero_grad()
            # A different seed every iteration (presumably so the Monte Carlo
            # noise differs between steps). Fewer samples than the 2048 used
            # for the reference renders.
            self.options.seed = t + 1
            self.options.spp = 128
            self.options.sppse1 = 128
            vertices = self.transform_vertices(self.vertices, self.param)
            image = self.render(vertices, 0, {"options": self.options})
            image_loss = (image - image_target).pow(2).mean()
            image_loss.backward()
            # Parameter-space RMSE, for logging only. It is measured before the
            # optimizer step, i.e. for the parameter that produced ``image``.
            with torch.no_grad():
                param_loss = (self.param - self.target_param).pow(2).mean().sqrt()
            self.optimizer.step()
            pyvredner.imwrite(image.detach().numpy(), "./results/iterations/iter_%d.exr" % (t))
            error_history.append([t, param_loss.item(), image_loss.item()])
            # ``.numpy()`` shares memory with ``self.param``, which the
            # optimizer updates in place, so snapshot it with ``copy()``.
            param_history.append(self.param.detach().numpy().copy())
            print(self.param)
            # Rewritten every iteration so the history survives an interrupted run.
            np.savetxt("./results/error.log", error_history)
            np.savetxt("./results/param.log", param_history)


if __name__ == "__main__":
    # Script entry point: construct the trainer (which loads the scene and
    # sets up the renderer), then run the full optimization.
    Trainer().run()
