# import psdr_cuda
import numpy as np
# import enoki as ek
# from enoki.cuda_autodiff import Float32 as FloatD, Vector3f as Vector3fD, Matrix4f as Matrix4fD
# from enoki.cuda import Float32 as FloatC
import torch
from torch.autograd import Variable
from torch import tensor as Tensor
import argparse
import cv2

from utils.file import create_output_dir
from utils.plot import image_loss, para_loss, write_obj
from utils.loss import compute_envmap_loss, compute_silhouette_loss, uniform_laplacian, texture_range_loss, texture_correlation_loss, total_variation_loss, compute_render_loss, compute_render_loss_mask

from random import randrange
from typing import List, Optional
import datetime

import scipy
import scipy.sparse.linalg
# from cupyx.scipy.sparse import linalg
# import cupyx
# import cupy as cp
import vredner
import pyvredner
import json
import copy
from scene_gen import draw_circle
from utils.utils import gen_camera

# import pyeltopo as etp
# from pyeltopo import intPtr, doublePtr

def sparse_eye(size):
    """Return a ``size`` x ``size`` identity matrix as a sparse COO tensor.

    Args:
        size: Number of rows (and columns) of the identity matrix.

    Returns:
        A sparse COO tensor of shape ``(size, size)`` holding ones on the
        diagonal, using torch's default floating-point dtype.
    """
    diagonal = torch.arange(size, dtype=torch.long)
    # Row index == column index for every stored entry.
    indices = torch.stack((diagonal, diagonal))
    values = torch.ones(size)
    # torch.sparse_coo_tensor replaces the legacy typed constructors
    # (torch.sparse.FloatTensor & co.) that were previously looked up by
    # string-parsing the dense tensor's type name.
    return torch.sparse_coo_tensor(indices, values, (size, size))


def adamax(params: List[torch.Tensor],
           grads: List[torch.Tensor],
           m1_tp: List[torch.Tensor],
           m2_tp: List[torch.Tensor],
           state_steps: List[int],
           *,
           beta1: float,
           beta2: float,
           lr: float,
           IL_term, IL_solver):
    r"""Apply one preconditioned, Adamax-style update to every parameter.

    For each parameter the gradient is first passed through ``IL_solver``,
    the first/second moment buffers are updated in place, and the step is
    taken on ``u = IL_term @ param``.  The step size is normalised by the
    single largest entry of the second-moment buffer (one scalar per
    parameter tensor, not one per element).  The result is mapped back
    through ``IL_solver`` and copied into the parameter in place.

    Args:
        params: Parameters to update (modified in place).
        grads: Gradient of each parameter, same order as ``params``.
        m1_tp: First-moment buffer of each parameter (modified in place).
        m2_tp: Second-moment buffer of each parameter (modified in place).
        state_steps: 1-based step counter of each parameter, used for the
            bias correction terms ``1 - beta ** step``.
        beta1: Decay rate of the first-moment estimate.
        beta2: Decay rate of the second-moment estimate.
        lr: Learning rate.
        IL_term: Matrix applied to each parameter with ``torch.matmul``.
        IL_solver: Callable that receives a NumPy array and returns the
            solved array; it is applied to the gradient and to the
            updated ``u``.
    """
    for i, param in enumerate(params):
        # Use dedicated local names for the per-parameter buffers: rebinding
        # ``m1_tp`` / ``m2_tp`` here would replace the argument lists with a
        # single tensor and break the lookup for every later parameter.
        m1 = m1_tp[i]
        m2 = m2_tp[i]
        step = state_steps[i]

        grad = torch.as_tensor(IL_solver(np.asarray(grads[i])))
        m1.mul_(beta1).add_(grad, alpha=1 - beta1)
        m2.mul_(beta2).add_(grad.square(), alpha=1 - beta2)

        u = torch.matmul(IL_term, param.detach())
        # Bias-corrected step; ``amax()`` collapses the second moment to a
        # scalar so every entry of this parameter shares one step scale.
        clr = lr / ((1 - beta1 ** step) * (m2.amax() /
                    (1 - beta2 ** step)).sqrt()) * m1
        u = u - clr
        param.copy_(torch.as_tensor(IL_solver(np.asarray(u))))


class LGDescent(torch.optim.Optimizer):
    """Optimizer that delegates its update to the module-level ``adamax``.

    Besides the usual ``lr`` / ``betas`` hyper-parameters, the optimizer
    stores ``IL_term`` and ``IL_solver`` and forwards them unchanged to
    ``adamax`` on every step.  Per-parameter state consists of a step
    counter and two zero-initialised moment buffers (``m1_tp``, ``m2_tp``).
    """

    def __init__(self, params, IL_term, IL_solver, lr=2e-3, betas=(0.9, 0.999)):
        """Validate the hyper-parameters and register the parameter groups.

        Raises:
            ValueError: If ``lr`` is negative or a beta is outside [0, 1).
        """
        if not (0.0 <= lr):
            raise ValueError("Invalid learning rate: {}".format(lr))
        for index in (0, 1):
            beta = betas[index]
            if not (0.0 <= beta < 1.0):
                raise ValueError(
                    "Invalid beta parameter at index {}: {}".format(index, beta))
        self.IL_term = IL_term
        self.IL_solver = IL_solver
        super().__init__(params, dict(lr=lr, betas=betas))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimisation step over every parameter group.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.

        Raises:
            RuntimeError: If a parameter carries a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            step_size = group['lr']

            active_params, active_grads = [], []
            first_moments, second_moments = [], []
            step_counts = []

            for p in group['params']:
                gradient = p.grad
                if gradient is None:
                    # Parameters that took no part in the backward pass.
                    continue
                active_params.append(p)
                if gradient.is_sparse:
                    raise RuntimeError(
                        'Adamax does not support sparse gradients')
                active_grads.append(gradient)

                state = self.state[p]
                if not state:
                    # First time this parameter is seen: create its state.
                    state.update(
                        step=0,
                        m1_tp=torch.zeros_like(
                            p, memory_format=torch.preserve_format),
                        m2_tp=torch.zeros_like(
                            p, memory_format=torch.preserve_format),
                    )

                first_moments.append(state['m1_tp'])
                second_moments.append(state['m2_tp'])

                state['step'] += 1
                step_counts.append(state['step'])

            adamax(active_params,
                   active_grads,
                   first_moments,
                   second_moments,
                   step_counts,
                   beta1=beta1,
                   beta2=beta2,
                   lr=step_size,
                   IL_term=self.IL_term, IL_solver=self.IL_solver)

        return loss


# Settings read by psdr_optimize(): "batch" is used as the number of
# sensors / target images; "radius" is looked up there as well.
config = dict(batch=100, radius=100)

# Module-level handles that psdr_optimize() rebinds (via ``global``) to the
# renderer's scene object and to the scene object that receives gradients.
c_scene = dict()
c_d_scene = dict()

def _load_target_pixels(path, resolution):
    """Load one target image and return it as a flat ``(N, 3)`` float tensor.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("Could not read target image: " + path)
    img = cv2.resize(img, dsize=(resolution[1], resolution[0]))
    # Swap the first and last channel (OpenCV's BGR order -> RGB).
    img = torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)).float()
    return img.reshape((-1, 3))


def _to_display_bgr8(rgb):
    """Convert a float RGB image to a gamma-2.2 encoded 8-bit BGR image."""
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    bgr = np.power(bgr, 1/2.2)
    return np.uint8(np.clip(bgr * 255., 0., 255.))


def psdr_optimize(OPT_Para):
    """Optimise the vertices of the first scene shape against target images.

    Reads the module-level ``output_dir``, ``target_path``, ``scene_init``
    and ``config`` as well as ``cam_pos.txt`` from the working directory,
    and rebinds the module-level ``c_scene`` / ``c_d_scene``.  Intermediate
    meshes and comparison images are written below ``output_dir + "iter/"``.

    Args:
        OPT_Para: Dict with the keys ``sensor_count``, ``debug_npass``,
            ``spp``, ``max_bounces``, ``sppe``, ``sppse``, ``print_size``
            and ``Stages`` (a list of dicts with ``npass``, ``lmbda`` and
            ``lr``).  Stages run in order; the first stage whose ``npass``
            is not positive ends the optimisation.

    Raises:
        ValueError: If ``sensor_count`` is not positive.
        FileNotFoundError: If a target image cannot be read.
    """
    global c_scene, c_d_scene

    sensor_count = OPT_Para['sensor_count']
    if sensor_count <= 0:
        # Without at least one rendered view there is no loss to backprop.
        raise ValueError(
            "OPT_Para['sensor_count'] must be positive, got {}".format(
                sensor_count))

    # --- cameras ---------------------------------------------------------
    batch = config["batch"]
    camera_pos = np.loadtxt("cam_pos.txt")
    resolution = [256, 256]
    cameras = gen_camera(positions=camera_pos,
                         target=[0., 0., 0.],
                         fov=10.,
                         resolution=resolution)
    c_cameras = [c.c_obj() for c in cameras]
    debug_npass = OPT_Para['debug_npass']
    iter_path = output_dir + "iter/"
    create_output_dir(iter_path)
    print("psdr_optimize!")
    direct_integrator = vredner.Direct()

    # Placeholder; replaced with a per-iteration seed before any render.
    options = vredner.RenderOptions(13,                 # random seed
                                    OPT_Para['spp'],    # spp
                                    OPT_Para['max_bounces'],  # max bounces
                                    OPT_Para['sppe'],   # sppe
                                    OPT_Para['sppse'],  # sppse0
                                    False)
    scene, _ = pyvredner.load_mitsuba(scene_init)
    c_scene = scene.c_obj()
    c_d_scene = c_scene.clone()
    c_d_scene.setZero()

    current_stage = 0
    num_sensors = batch
    npixels = resolution[0] * resolution[1]

    # --- target images ---------------------------------------------------
    tar_img_store = [
        _load_target_pixels(target_path + "sensor_%d.exr" % sid, resolution)
        for sid in range(num_sensors)
    ]

    # Fixed views that are re-rendered for the progress images.
    test_array = [0, 14, 32, 57]
    test_img_store = [
        _load_target_pixels(target_path + "sensor_%d.exr" % sid, resolution)
        for sid in test_array
    ]

    # Store the target test images side by side as the reference strip.
    print("target camera:", test_array)
    t_res = (resolution[1], resolution[0], 3)
    target_result_img = np.concatenate(
        [img.reshape(t_res).numpy() for img in test_img_store], axis=1)
    target_result_img = _to_display_bgr8(target_result_img)
    cv2.imwrite(iter_path+"/target.png", target_result_img)

    np.random.seed(len(tar_img_store))
    shffle_array = np.arange(len(tar_img_store))
    np.random.shuffle(shffle_array)
    iter_num = 0

    class PSDRRender(torch.autograd.Function):
        """Autograd bridge: vertices -> rendered image via ``vredner``."""

        @staticmethod
        def forward(ctx, V: torch.Tensor, sensor_id):
            """Render the scene with vertices ``V`` from camera ``sensor_id``."""
            global c_scene
            scene.shapes[0].vertices = V.detach()
            c_scene = scene.c_obj()
            ctx.sensor_id = sensor_id
            c_scene.camera = c_cameras[sensor_id]
            image = torch.zeros(resolution[1],
                                resolution[0],
                                3)
            # The renderer writes into the tensor's memory through a raw pointer.
            direct_integrator.render(
                c_scene, options, vredner.float_ptr(image.data_ptr()))
            return image

        @staticmethod
        def backward(ctx, grad_out: torch.Tensor):
            """Return the gradient w.r.t. the vertices (None for sensor_id)."""
            global c_scene, c_d_scene

            # The raw-pointer hand-off below needs densely packed memory.
            grad_out = grad_out.contiguous()

            for shape_idx in (1, 2, 3):
                c_scene.shape_list[shape_idx].enable_edge = False

            # PrimaryEdgeIntegrator = vredner.PrimaryEdgeIntegrator(c_scene)
            # DirectEdgeIntegrator = vredner.DirectEdgeIntegrator(c_scene)
            # IndirectEdgeIntegrator = vredner.IndirectEdgeIntegrator(c_scene)

            # Reset the scene object that accumulates the gradients.
            c_d_scene.setZero()
            #! might be problematic
            c_scene.camera = c_cameras[ctx.sensor_id]
            dummy_image = torch.zeros(resolution[1], resolution[0], 3)

            direct_integrator.d_render(
                c_scene, c_d_scene, options,
                vredner.float_ptr(grad_out.data_ptr()),
                vredner.float_ptr(dummy_image.data_ptr()))

            # PrimaryEdgeIntegrator.d_render(c_scene, c_d_scene, options,
            #                     vredner.float_ptr(grad_out.data_ptr()),
            #                     vredner.float_ptr(dummy_image.data_ptr()))

            # aq_config = vredner.aq_config(0.000001, 16, 0.5, 0, 5)

            # aq_config.use_heap = True
            # aq_config.edge_draw = False
            # aq_config.sample_decay = 0.5
            # aq_config.min_spg = 8
            # aq_config.max_depth_x = 16
            # aq_config.max_depth_y = 24
            # aq_config.max_depth_z = 24
            # DirectEdgeIntegrator.preprocess_aq(c_scene, aq_config, OPT_Para['max_bounces'])

            # DirectEdgeIntegrator.d_render(c_scene, c_d_scene, options,
            #                     vredner.float_ptr(grad_out.data_ptr()),
            #                     vredner.float_ptr(dummy_image.data_ptr()))

            # IndirectEdgeIntegrator.d_render(c_scene, c_d_scene, options,
            #                     vredner.float_ptr(grad_out.data_ptr()),
            #                     vredner.float_ptr(dummy_image.data_ptr()))

            grad_vertx = torch.tensor(c_d_scene.shape_list[0].vertices)
            # NaN gradients would poison the optimizer state; drop them.
            grad_vertx[grad_vertx.isnan()] = 0.
            return (grad_vertx,
                    None)

    # NOTE(review): every other access goes through ``scene.shapes`` or
    # ``c_scene.shape_list`` -- confirm ``c_scene.shapes`` is intended here.
    F = torch.tensor(c_scene.shapes[0].indices).long()
    E = torch.tensor(c_scene.shape_list[0].getEdges()).long()
    V = scene.shapes[0].vertices
    V.requires_grad = True

    n_verts = V.shape[0]
    n_faces = F.shape[0]

    # Only referenced by the disabled El Topo step further down.
    new_verts  = torch.zeros_like(V).cpu().double().contiguous()
    out_verts  = torch.zeros_like(V).cpu().double().contiguous()
    masses     = torch.ones((n_verts, 1)).cpu().double().contiguous()

    print("n_verts:", n_verts, " n_faces:", n_faces)

    print('Starting optimization')

    stages = OPT_Para['Stages']
    for stage_id, stage in enumerate(stages):
        # Select the stage first so the guard below inspects this stage's
        # pass count rather than the previous stage's.
        current_stage = stage_id
        npass = stage['npass']
        if npass <= 0:
            break

        L = uniform_laplacian(V, E).detach() * stage['lmbda']
        I = sparse_eye(L.shape[0])
        IL_term = I + L
        IL_coo = IL_term.coalesce()
        Lv = np.asarray(IL_coo.values())
        Li = np.asarray(IL_coo.indices())

        # Hand the factorisation a CSC matrix directly instead of a COO one.
        IL_term_s = scipy.sparse.coo_matrix((Lv, Li), shape=L.shape).tocsc()
        IL_term_s_solver = scipy.sparse.linalg.factorized(IL_term_s)

        params_mesh = [
            {'params': V, 'lr': stage['lr']},
        ]

        optimizer_mesh = LGDescent(params_mesh, IL_term, IL_term_s_solver)

        render = PSDRRender.apply

        for nn in range(npass+1):
            # PSDRRender reads ``options`` from this scope; reseed per pass.
            options = vredner.RenderOptions(nn,                 # random seed
                                            OPT_Para['spp'],    # spp
                                            OPT_Para['max_bounces'],  # max bounces
                                            OPT_Para['sppe'],   # sppe
                                            OPT_Para['sppse'],  # sppse0
                                            False)
            now = datetime.datetime.now()
            optimizer_mesh.zero_grad()
            active_sensor = [shffle_array[randrange(num_sensors)]
                             for _ in range(sensor_count)]

            loss = 0.
            for sensor_id in active_sensor:
                img = render(V, sensor_id)
                c_scene = scene.c_obj()
                tar_img = tar_img_store[sensor_id]
                img_loss = compute_render_loss(
                    img.reshape((-1, 3)), tar_img, 1.0, npixels)
                print("Rendering camera:", sensor_id, img_loss, end="\n")
                loss += img_loss
            loss.backward()

            print("grad: ", V.grad.abs().sum())
            a = V.detach().numpy().copy()

            # Eltopo step
            # _curr_verts = V.detach().cpu().double().contiguous()
            optimizer_mesh.step()

            # _new_verts = V.detach().cpu().double().contiguous()
            # _curr_faces = F.detach().cpu().int().contiguous()
            # _masses = masses.detach().cpu().double().contiguous()
            # _out_verts = out_verts.detach().cpu().double().contiguous()
            # trimesh = etp.TriMesh(
            #     doublePtr(_curr_verts.data_ptr()),
            #     intPtr(_curr_faces.data_ptr()),
            #     doublePtr(_masses.data_ptr()),
            #     int(n_verts), int(n_faces),
            #     doublePtr(_new_verts.data_ptr()),
            #     doublePtr(_out_verts.data_ptr()),
            #     1.0, 1.0
            # )
            # etp.ElTopoStep(trimesh)
            # print()
            # temp = Variable(_out_verts.clone().float(), requires_grad=True)
            # V.data = temp.data
            # print("V2: ",V.mean())

            b = V.detach().numpy().copy()
            d = np.linalg.norm(a - b)
            print("step : ", d)

            if iter_num % OPT_Para['print_size'] == 0:
                write_obj(V.detach(),
                          F.detach(),
                          iter_path+"iter_" + str(iter_num) + ".obj")

                # NOTE(review): ``c_scene`` was created before the optimizer
                # step -- confirm it reflects the updated vertices.
                temp_img_array = []
                for test_id in test_array:
                    print("Rendering debug camera:", test_id, end="\r")
                    c_scene.camera = c_cameras[test_id]
                    debug_img = torch.zeros(resolution[1],
                                            resolution[0],
                                            3)
                    if debug_npass > 0:
                        # Average ``debug_npass`` differently seeded renders.
                        # A local options object keeps the training
                        # ``options`` used by PSDRRender untouched.
                        for ii in range(debug_npass):
                            debug_options = vredner.RenderOptions(
                                ii,                       # random seed
                                OPT_Para['spp'],          # spp
                                OPT_Para['max_bounces'],  # max bounces
                                0,                        # sppe
                                0,                        # sppse0
                                False)
                            pass_img = torch.zeros(resolution[1],
                                                   resolution[0],
                                                   3)
                            direct_integrator.render(
                                c_scene, debug_options,
                                vredner.float_ptr(pass_img.data_ptr()))
                            pass_img /= float(debug_npass)
                            debug_img += pass_img
                    else:
                        # No averaging requested: one render with the
                        # current training options.
                        direct_integrator.render(
                            c_scene, options,
                            vredner.float_ptr(debug_img.data_ptr()))
                    temp_img = debug_img.numpy().reshape(
                        resolution[1], resolution[0], 3)
                    temp_img_array.append(temp_img)

                debug_result_img = _to_display_bgr8(
                    np.concatenate(temp_img_array, axis=1))
                # Targets on top, current renders below.
                iter_result_img = np.concatenate(
                    [target_result_img, debug_result_img], axis=0)
                cv2.imwrite(iter_path+"img_out_"+str(iter_num) +
                            ".png", iter_result_img)
            print()
            end = datetime.datetime.now() - now
            print('Stage:', current_stage,
                  'Iteration:', iter_num,
                  'Loss:', loss.item(),
                  'Total time:', end.total_seconds())
            iter_num += 1


if __name__ == "__main__":
    # Command-line entry point: resolve the paths that psdr_optimize() reads
    # as module-level globals, load the JSON settings and run the optimisation.
    parser = argparse.ArgumentParser(
        prog='ESSVBRDF',
        description='Envmap shape SVBRDF optimization',
        epilog='Kai Yan (kyan8@uci.edu)'
    )
    parser.add_argument('--scene', required=True, type=str)
    # Name of the settings file without its ".json" extension.
    parser.add_argument('--json', required=True, type=str)
    args = parser.parse_args()

    scene_name = args.scene
    output_dir = "./output/" + scene_name + "_" + args.json + "/"
    target_path = "./output/" + scene_name + "/target/"
    scene_init = "./ini.xml"

    create_output_dir("./output/")
    create_output_dir(output_dir)

    # Read through a context manager so the file handle is always closed.
    with open(args.json+".json", encoding="utf-8") as json_file:
        config_file = json_file.read()
    OPT_config = json.loads(config_file)
    psdr_optimize(OPT_config)
