# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import os
sys.path.insert(1, "../../")

import gin
import torch
import TensorRay as TR
import numpy as np
import scipy
import scipy.sparse.linalg
from dataclasses import dataclass, field
import argparse
import datetime
from typing import List
from random import seed, randrange

from pyTensorRay.fwd import *
from pyTensorRay.common import get_int_ptr, mkdir
from pyTensorRay.utils import image_tensor_to_torch, save_torch_image, read_images, update_render_batch_options
from pyTensorRay.common import get_int_ptr, get_ptr, mkdir
from pyTensorRay.metrics import TensorboardDashboard, compute_mesh_distance
from pyTensorRay.multi_view import gen_camera
from pyTensorRay.parameter_manager import SceneParameterManager
from pyTensorRay.renderer_multi_view import Render
from pyTensorRay.loss import *
from pyTensorRay.optimizer import LargeSteps, LargeSteps2, LargeSteps3

def gen_neighbour_for_texture(width, height):
    """Build triangle connectivity over a ``width`` x ``height`` texel grid.

    Texels are numbered row-major (``row * width + col``). Every grid cell,
    i.e. every texel that has both a right and a lower neighbour, contributes
    two consecutive rows to the result:

    * ``[t, t + 1, t + width]``          (texel, right, below)
    * ``[t, t + width, t + width + 1]``  (texel, below, below-right)

    Cells are emitted row by row, so the two faces of cell ``(i, j)`` live at
    rows ``2 * (i * (width - 1) + j)`` and the one after it.

    Args:
        width: number of texels per row.
        height: number of texel rows.

    Returns:
        ``torch.int32`` tensor of shape ``((width - 1) * (height - 1) * 2, 3)``;
        it is empty (``(0, 3)``) when either dimension is smaller than 2.
    """
    width = int(width)
    height = int(height)
    cells_per_row = max(width - 1, 0)
    cell_rows = max(height - 1, 0)

    F = torch.zeros((cell_rows * cells_per_row * 2, 3), dtype=torch.int32)
    if F.shape[0] == 0:
        return F

    rows = torch.arange(cell_rows, dtype=torch.int32).view(-1, 1)
    cols = torch.arange(cells_per_row, dtype=torch.int32).view(1, -1)
    # The texel index must use the full row stride `width`; only the face slot
    # uses the cell stride `width - 1`. Mixing the two makes triangles in every
    # row but the first point at texels of the wrong row/column.
    texel = (rows * width + cols).reshape(-1)

    F[0::2] = torch.stack([texel, texel + 1, texel + width], dim=1)
    F[1::2] = torch.stack([texel, texel + width, texel + width + 1], dim=1)
    return F

@gin.configurable
@dataclass
class TrainRunner:
    """Gin-configurable driver that optimises a scene against target images.

    Construction (``__post_init__``) loads the initial and target scenes, the
    camera list, the target images, and builds the ``Render`` object.
    ``run()`` then performs the burn-in passes followed by ``niter``
    optimisation iterations, writing images, meshes, and a loss log below
    ``out_dir``.
    """

    # Scene description that is optimised.
    scene_init_file: str = "./scenes/sphere_to_cube/init.xml"
    # Scene description whose meshes are used as the mesh-distance reference.
    scene_target_file: str = "./scenes/sphere_to_cube/tar.xml"
    # Root directory for all outputs (images, meshes, loss.log, logdir).
    out_dir: str = "./output/sphere_to_cube/"
    # Directory the target images are read from.
    target_dir: str = "./output/sphere_to_cube/target/"
    # Indices into the target image stack that are concatenated into test.png.
    test_indices: List = field(default_factory=lambda: [0, 14, 29, 40])
    # Optional list of camera-id batches. When non-empty, one batch is picked
    # at random per iteration and shuffled in place.
    training_index: List[List[int]] = field(default_factory=lambda: [])
    # The next three fields are not read by the code in this class.
    is_camera_stratified: bool = False
    atten_iter: List[List[int]] = field(default_factory=lambda: [])
    lr_schedule: List[float] = field(default_factory=lambda: [])

    # Shapes whose vertex positions are differentiated.
    shapes_id: List[int] = field(default_factory=lambda: [])
    # BSDFs whose textures are differentiated.
    bsdfs_id: List[int] = field(default_factory=lambda: [])
    niter: int = 500
    lr: float = 0.1
    # Forwarded to LargeSteps2 as its last positional argument.
    lambda_value: float = 10.0
    # Number of full passes over all cameras before optimisation starts.
    burn_in_iter: int = 32
    # Number of camera renders accumulated per optimisation step.
    batch_size: int = 10
    # Mesh distance is logged every `print_size` iterations.
    print_size: int = 10

    # Text file loaded with np.loadtxt; one camera per row (10 columns used).
    camera_file: str = "./output/sphere_to_cube/cam_pos.txt"

    # Must be configured (e.g. through gin) before run(); run() reads `.seed`.
    options: TR.RenderOptions = None
    # NOTE: these two defaults are evaluated once when the class body executes,
    # so every TrainRunner built without an override shares the same objects.
    integrator: TR.Integrator = TR.PathTracer()
    ref_integrator: TR.Integrator = TR.PathTracer()
    edge_integrators: List = field(
        default_factory=lambda: [TR.PixelBoundaryIntegrator(), TR.PrimaryEdgeIntegrator(), TR.DirectEdgeIntegrator(), TR.IndirectEdgeIntegrator()])

    def __post_init__(self):
        """Create output folders and load scenes, cameras, targets, and renderer."""
        mkdir(self.out_dir)
        mkdir(os.path.join(self.out_dir, "iter"))
        mkdir(os.path.join(self.out_dir, "iter", "train"))
        self.scene_init = Scene(self.scene_init_file)
        self.width = self.scene_init.get_width(0)
        self.height = self.scene_init.get_height(0)
        # TODO: passing a dict to set up requires_grad

        for shape_id in self.shapes_id:
            self.scene_init.shapes[shape_id].diff_all_vertex_pos()

        for bsdf_id in self.bsdfs_id:
            self.scene_init.bsdfs[bsdf_id].diff_texture()

        # ndmin=2 keeps a single-camera file as a (1, N) array instead of a
        # 1-D vector, so the per-row slicing below works for any camera count.
        cameras_info = np.loadtxt(self.camera_file, ndmin=2)
        self.cameras = [
            gen_camera(cam[0:3], cam[3:6], cam[6:9], cam[9], [self.width, self.height])
            for cam in cameras_info
        ]
        # for c in cameras_info:
        #     self.scene_init.update_env_map_bbox(TR.Vector3f(c[0], c[1], c[2]))

        # The scene is configured both before and after the parameter manager
        # is created; the order is kept exactly as it was.
        self.scene_init.configure()
        self.scene_param_manager = SceneParameterManager([0])
        self.scene_param_manager.get_values()
        print(self.scene_param_manager)
        self.scene_init.configure()

        # Read the reference geometry once; it is only used for the mesh
        # distance metric reported during training.
        self.scene_target = Scene(self.scene_target_file)
        self.scene_target.configure()
        self.V_targets = []
        self.F_targets = []
        for shape_id in self.shapes_id:
            shape = self.scene_target.shapes[shape_id]
            face_count = shape.get_face_count()
            vertex_count = shape.get_vertex_count()
            V = torch.zeros([vertex_count, 3], dtype=torch.float32)
            F = torch.zeros([face_count, 3], dtype=torch.int32)
            # The shape writes into the tensors through their raw pointers.
            shape.get_vertex_pos(get_ptr(V))
            shape.get_face_indices(get_int_ptr(F))
            self.V_targets.append(V.numpy())
            self.F_targets.append(F.numpy())

        self.train_images = torch.stack(read_images(self.target_dir))
        self.train_image_pyramids = [build_pyramid(self.train_images[i]) for i in range(self.train_images.shape[0])]

        # Side-by-side strip of the selected target views, saved for reference.
        self.test_images = self.train_images[self.test_indices]
        self.test_image = torch.cat(list(self.test_images), axis=1)
        save_torch_image(os.path.join(self.out_dir, "test.png"), self.test_image)

        # TODO: check if RenderOptions.seed is useful

        # set up the dashboard
        dashboard_path = os.path.join(self.out_dir, 'logdir')
        self.dashboard = TensorboardDashboard(log_dir=dashboard_path)

        self.render = Render(self.scene_init, self.cameras, self.scene_param_manager,
                             self.integrator, self.ref_integrator, self.edge_integrators, self.options, None)

    def run(self):
        """Run burn-in followed by ``niter`` optimisation iterations.

        Raises:
            ValueError: if ``options`` was never configured (it is ``None``).
        """
        if self.options is None:
            raise ValueError("TrainRunner.options must be configured before calling run()")

        # Face indices of the first optimised shape: handed to the optimizer
        # and reused for the mesh-distance metric.
        shape = self.scene_init.shapes[self.shapes_id[0]]
        F = torch.zeros([shape.get_face_count(), 3], dtype=torch.int32)
        shape.get_face_indices(get_int_ptr(F))
        F_cur = [F.numpy()]
        # Alternatives that were tried here before: LargeSteps, torch.optim.Adagrad.
        optimizer1 = LargeSteps2(self.scene_param_manager.params[0], F, self.lr, (0.9, 0.999), self.lambda_value)

        seed(self.options.seed)
        error = []

        self.dashboard.add_function('Loss/train')
        self.dashboard.add_function('Loss/Mesh Distance')
        self.dashboard.add_function('Test Image')

        TR.set_rnd_seed(self.options.seed)
        bn_T = self._burn_in()

        T = 0.0
        for i_iter in range(self.niter):
            now = datetime.datetime.now()
            optimizer1.zero_grad()
            loss = self._render_batch(i_iter)

            optimizer1.step()
            # Push the updated parameters back into the scene.
            self.scene_param_manager.set_values()
            self.scene_init.configure()
            self.dashboard.write_to_summary('Loss/train', 'scalar', loss)

            # print stats
            elapsed = (datetime.datetime.now() - now).total_seconds()
            T += elapsed
            print("[INFO] iter = {:d}, loss = {:.3f}, time = {:.3f}".format(
                i_iter, loss, elapsed))
            error.append(loss)
            # The whole loss history is rewritten every iteration.
            np.savetxt(os.path.join(self.out_dir, "loss.log"), error)

            self._export_progress(i_iter, F_cur)

            self.render.step()
            self.dashboard.step()
        print("burn-in time: {}".format(bn_T))
        print("final time: {}".format(T))

    def _burn_in(self):
        """Render every camera ``burn_in_iter`` times; return elapsed seconds."""
        start = datetime.datetime.now()
        self.scene_init.configure()
        for i_iter in range(self.burn_in_iter):
            for camera_id in range(len(self.cameras)):
                self.render(self.scene_param_manager.params, camera_id,
                            {"seed": int(self.options.seed + i_iter * 1e5), "phase": "training" })
            self.render.step()
        # total_seconds() also accounts for whole days, unlike `.seconds`.
        return (datetime.datetime.now() - start).total_seconds()

    def _render_batch(self, i_iter):
        """Render ``batch_size`` views, backpropagate each loss, return their sum."""
        if len(self.training_index) > 0:
            batch_index = randrange(0, len(self.training_index))
            np.random.shuffle(self.training_index[batch_index])
        else:
            # No predefined batches: cameras are drawn uniformly at random.
            batch_index = -1

        loss = 0.0
        last_view = self.batch_size - 1
        for j in range(self.batch_size):
            if batch_index >= 0:
                camera_id = self.training_index[batch_index][j]
                print(camera_id)
            else:
                camera_id = randrange(0, len(self.cameras))

            img = self.render(self.scene_param_manager.params, camera_id,
                              {"seed": int(i_iter * 1e5), "phase": "training" })
            # Sanitise the render before it enters the loss.
            img[img.isnan()] = 0.0
            img = torch.clamp(img, 0.0, 4.0)
            if j == last_view:
                # Only the sanitised image is written. The unclamped copy that
                # used to be saved first went to this same path and was
                # overwritten immediately.
                save_torch_image(os.path.join(self.out_dir, "iter", "iter_{}.exr".format(i_iter)), img.clone().detach())
            # compute loss and backward
            # Other losses tried here: compute_render_loss_L2,
            # compute_render_loss_pyramid_L1 (with self.train_image_pyramids).
            img_loss = compute_tanh_loss_L2(img, self.train_images[camera_id])
            loss += img_loss.item()
            img_loss.backward()
            print("Rendering camera: {:d}, loss: {:.3f}".format(camera_id, img_loss.item()), end='\r')
        return loss

    def _export_progress(self, i_iter, F_cur):
        """Periodically export meshes and log the mesh distance to the target."""
        # It's quite slow when the mesh is large.
        if (i_iter + 1) % 16 == 0:
            for shape_id in self.shapes_id:
                cur_mesh_path = os.path.join(self.out_dir, "iter/iter_{:d}_{:04d}.obj".format(shape_id, i_iter))
                self.scene_init.shapes[shape_id].export_mesh(cur_mesh_path)

        if i_iter % self.print_size != 0:
            return

        # NOTE(review): F_cur only holds the faces of shapes_id[0] and the
        # vertices are always read from params[0][0], so this loop raises
        # IndexError as soon as more than one shape is listed. Confirm the
        # intended per-shape parameter layout before generalising.
        for i_shape, shape_id in enumerate(self.shapes_id):
            cur_mesh_path = os.path.join(self.out_dir, "iter/iter_{:d}_{:04d}.npy".format(shape_id, i_iter))
            V_cur = self.scene_param_manager.params[0][0].detach().clone()
            np.save(cur_mesh_path, V_cur.numpy())
            # Reinterpret the flat parameter as 3 rows and transpose to (N, 3).
            V_cur = V_cur.view(3, -1).transpose(0, 1).numpy()
            mesh_dist = compute_mesh_distance(V_cur, F_cur[i_shape], self.V_targets[i_shape], self.F_targets[i_shape])
            self.dashboard.write_to_summary('Loss/Mesh Distance', 'scalar', mesh_dist)
            print("[INFO] mesh loss =", mesh_dist)


if __name__ == "__main__":
    # Script entry point: parse the optional gin config path, then train.
    default_config = "sphere_to_cube.conf"
    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", metavar="config_file", type=str, nargs="?",
                        default=default_config, help="config file")
    # Unrecognised command-line arguments are tolerated and ignored.
    args, _unknown = parser.parse_known_args()
    gin.parse_config_file(args.config_file, skip_unknown=True)

    TR.env_create()
    try:
        train_runner = TrainRunner()
        train_runner.run()
    finally:
        # Release the TensorRay environment even if setup or training fails.
        TR.env_release()
