import psdr_cuda
import numpy as np
import enoki as ek
from enoki.cuda_autodiff import Float32 as FloatD, Vector3f as Vector3fD, Matrix4f as Matrix4fD
from enoki.cuda import Float32 as FloatC
import torch 
from torch.autograd import Variable
from torch import tensor as Tensor
import argparse
import cv2

from utils.file import create_output_dir
from utils.plot import image_loss, para_loss
from utils.loss import compute_silhouette_loss, uniform_laplacian, texture_range_loss, total_variation_lossD, texture_correlation_loss, total_variation_loss, compute_render_loss, compute_render_loss_mask

from random import randrange
from typing import List, Optional
import datetime

from cupyx.scipy.sparse import linalg
import cupyx
import cupy as cp

import json

def sparse_eye(size, device="cuda"):
    """Return a ``size`` x ``size`` identity matrix as a sparse COO tensor.

    Args:
        size: Number of rows (and columns) of the identity matrix.
        device: Device on which the tensor is allocated. Defaults to
            ``"cuda"``, the device this helper previously hard-coded.

    Returns:
        A sparse COO tensor of shape ``(size, size)`` in the default floating
        point dtype, with a one stored for every diagonal entry.
    """
    diagonal = torch.arange(size, device=device, dtype=torch.long)
    # A diagonal entry has identical row and column indices.
    indices = torch.stack((diagonal, diagonal))
    values = torch.ones(size, device=device)
    # torch.sparse_coo_tensor replaces the legacy ``torch.sparse.FloatTensor``
    # constructor that was previously resolved by parsing ``values.type()``.
    return torch.sparse_coo_tensor(indices, values, (size, size))


def adamax(params: List[Tensor],
           grads: List[Tensor],
           m1_tp: List[Tensor],
           m2_tp: List[Tensor],
           state_steps: List[int],
           *,
           beta1: float,
           beta2: float,
           lr: float,
           IL_term, IL_solver):
    r"""Apply one preconditioned, Adamax-style update to every parameter.

    For each parameter the gradient is first passed through ``IL_solver``.
    The filtered gradient updates exponential moving averages of the gradient
    and of its square. The step is taken on ``IL_term @ param`` and the result
    is passed through ``IL_solver`` again before being written back into the
    parameter. The largest entry of the second-moment estimate is used as a
    single scale for the whole parameter.

    The parameters are modified in place with ``copy_``; the caller is
    expected to run this with autograd disabled (``LGDescent.step`` does so
    through ``torch.no_grad``).

    Args:
        params: Parameters to update in place.
        grads: Gradient of each entry of ``params``.
        m1_tp: First-moment buffers, one per parameter, updated in place.
        m2_tp: Second-moment buffers, one per parameter, updated in place.
        state_steps: 1-based step counter of each parameter, used for the
            bias correction.
        beta1: Decay rate of the first-moment average.
        beta2: Decay rate of the second-moment average.
        lr: Learning rate.
        IL_term: Matrix multiplied with each parameter before the step.
        IL_solver: Callable taking a cupy array and returning an array of the
            same shape; in this file it is the factorized solver of the
            ``IL_term`` system.
    """
    # Iterate with dedicated local names: rebinding ``m1_tp``/``m2_tp`` to
    # their first element would break every parameter after the first one.
    for param, raw_grad, first_moment, second_moment, step in zip(
            params, grads, m1_tp, m2_tp, state_steps):
        grad = torch.as_tensor(IL_solver(cp.asarray(raw_grad)), device='cuda')
        first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
        second_moment.mul_(beta2).add_(grad.square(), alpha=1 - beta2)

        u = torch.matmul(IL_term, param.detach())
        # Bias-corrected step; amax() yields one scalar scale per parameter.
        bias1 = 1 - beta1 ** step
        bias2 = 1 - beta2 ** step
        clr = lr / (bias1 * (second_moment.amax() / bias2).sqrt()) * first_moment
        u = u - clr
        param.copy_(torch.as_tensor(IL_solver(cp.asarray(u)), device='cuda'))


class LGDescent(torch.optim.Optimizer):
    """Optimizer that hands every parameter group to the ``adamax`` routine.

    The optimizer keeps a matrix (``IL_term``) and a solver callable
    (``IL_solver``) and forwards both, together with per-parameter moment
    buffers and step counters, to ``adamax`` on every ``step``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        IL_term: Matrix forwarded to ``adamax``.
        IL_solver: Solver callable forwarded to ``adamax``.
        lr: Default learning rate of the parameter groups.
        betas: Pair of decay rates for the first and second moment buffers.

    Raises:
        ValueError: If ``lr`` is negative or a beta lies outside ``[0, 1)``.
    """
    def __init__(self, params, IL_term, IL_solver, lr=2e-3, betas=(0.9, 0.999)):
        if not (0.0 <= lr):
            raise ValueError("Invalid learning rate: {}".format(lr))
        first_beta = betas[0]
        if not (0.0 <= first_beta < 1.0):
            raise ValueError("Invalid beta parameter at index 0: {}".format(first_beta))
        second_beta = betas[1]
        if not (0.0 <= second_beta < 1.0):
            raise ValueError("Invalid beta parameter at index 1: {}".format(second_beta))
        self.IL_term = IL_term
        self.IL_solver = IL_solver
        super().__init__(params, dict(lr=lr, betas=betas))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step over all parameter groups.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: If a parameter carries a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']

            # Parallel lists describing every parameter that has a gradient.
            active_params = []
            gradients = []
            first_moments = []
            second_moments = []
            step_counts = []

            for p in group['params']:
                gradient = p.grad
                if gradient is None:
                    continue
                if gradient.is_sparse:
                    raise RuntimeError('Adamax does not support sparse gradients')
                active_params.append(p)
                gradients.append(gradient)

                state = self.state[p]
                if not state:
                    # Lazily create the buffers on the first step of p.
                    state['step'] = 0
                    state['m1_tp'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['m2_tp'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                first_moments.append(state['m1_tp'])
                second_moments.append(state['m2_tp'])

                state['step'] += 1
                step_counts.append(state['step'])

            adamax(active_params,
                   gradients,
                   first_moments,
                   second_moments,
                   step_counts,
                   beta1=beta1,
                   beta2=beta2,
                   lr=lr,
                   IL_term=self.IL_term, IL_solver=self.IL_solver)

        return loss

def psdr_optimize(OPT_Para):
    """Optimize mesh vertices and BSDF textures against target images.

    The function loads every scene listed in the module-level ``scene_init``,
    reads the target renderings and silhouettes from ``target_path`` and
    ``sli_path``, and then runs the stages described by ``OPT_Para['Stages']``.
    Each iteration renders ``sensor_count`` randomly chosen cameras, averages
    their gradients and updates the vertices with ``LGDescent`` and the
    textures with ``torch.optim.Adam``. Intermediate textures, meshes and
    comparison images are written below ``output_dir + "iter/"``.

    The module-level names ``output_dir``, ``scene_init``, ``target_path`` and
    ``sli_path`` must be defined before the call (the ``__main__`` section of
    this file does so).

    Args:
        OPT_Para: Configuration dict. The keys read here are
            ``sensor_count``, ``R_size``, ``D_size``, ``S_size``, ``spp``,
            ``sppe``, ``sppse``, ``debug_npass``, ``print_size``,
            ``batch_size`` and ``Stages``. Every entry of ``Stages`` is a dict
            providing ``npass``, ``lmbda``, ``lr``, ``lr_roug``, ``lr_diff``,
            ``lr_spec``, ``wt_image``, ``wt_sli``, ``wt_leftR``,
            ``wt_rightR``, ``wt_texture_rang``, ``wt_texture_corrS``,
            ``wt_texture_corrR``, ``wt_total_varD`` and ``wt_total_varR``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    torch.cuda.empty_cache()
    ek.cuda_malloc_trim()

    sensor_count = OPT_Para['sensor_count']

    R_res = OPT_Para['R_size']
    D_res = OPT_Para['D_size']
    S_res = OPT_Para['S_size']

    iter_path = output_dir + "iter/"
    create_output_dir(iter_path)
    print("psdr_optimize!")
    direct_integrator = psdr_cuda.DirectIntegrator(1, 1, 1)
    silhouette_integrator = psdr_cuda.FieldExtractionIntegrator("silhouette")

    # ``current_stage`` and ``iter_num`` are read by PSDRRender.forward through
    # the closure, so they have to stay locals of this function.
    current_stage = 0
    mesh_key = 'Mesh[0]'

    sc = []
    num_sensors = []
    for scene_file in scene_init:
        scene = psdr_cuda.Scene()
        scene.load_file(scene_file, False)
        scene.opts.spp = OPT_Para['spp']
        scene.opts.sppe = OPT_Para['sppe']
        scene.opts.sppse = OPT_Para['sppse']
        scene.opts.log_level = 0
        scene.configure()
        sc.append(scene)
        num_sensors.append(scene.num_sensors)
    ro1 = sc[0].opts

    debug_npass = OPT_Para['debug_npass']
    total_sensors = sum(num_sensors)

    # Target colour images, one flattened (pixels, 3) tensor per sensor.
    tar_img_store = []
    for curr_sensor_id in range(0, total_sensors):
        curr_tar_img = cv2.imread(target_path+"exr/sensor_%d.exr" % curr_sensor_id, cv2.IMREAD_UNCHANGED)
        curr_tar_img = cv2.resize(curr_tar_img, dsize=(ro1.width, ro1.height))
        curr_tar_img = torch.from_numpy(cv2.cvtColor(curr_tar_img, cv2.COLOR_RGB2BGR)).float()
        tar_img_store.append(curr_tar_img.reshape((-1, 3)))

    # Target silhouettes: first channel only, flattened to one value per pixel.
    tar_sli_store = []
    for curr_sensor_id in range(0, total_sensors):
        curr_sli_img = cv2.imread(sli_path+"exr/sensor_%d.exr" % curr_sensor_id, cv2.IMREAD_UNCHANGED)
        curr_sli_img = cv2.resize(curr_sli_img, dsize=(ro1.width, ro1.height))
        curr_sli_img = torch.from_numpy(curr_sli_img).float()[:, :, 0]
        tar_sli_store.append(curr_sli_img.reshape((-1, 1))[:, 0])
    iter_num = 0

    # sensor_sum[k] is the global index of the first sensor of scene k.
    sensor_sum = []
    sum_sen = 0
    for sen in num_sensors:
        sensor_sum.append(sum_sen)
        sum_sen += sen

    np.random.seed(64)
    # Fixed (scene, sensor) pairs rendered for the visual progress images.
    tar_img_array = []
    test_array = [[0, 32], [1, 32], [0, 64], [1, 64]]
    for test_scene_id, test_sensor_id in test_array:
        conca = sensor_sum[test_scene_id]
        if not test_sensor_id < num_sensors[test_scene_id]:
            print("error target image")
            exit()
        curr_tar_img = cv2.imread(target_path+"exr/sensor_%d.exr" % (test_sensor_id+conca), cv2.IMREAD_UNCHANGED)
        curr_tar_img = cv2.resize(curr_tar_img, dsize=(ro1.width, ro1.height))
        curr_tar_img = torch.from_numpy(cv2.cvtColor(curr_tar_img, cv2.COLOR_RGB2BGR)).float()
        tar_img_array.append(curr_tar_img.reshape((ro1.height, ro1.width, 3)))

    print("target camera:", test_array)

    # Side-by-side strip of the target views, gamma-mapped to 8 bit.
    result_img = np.concatenate(tar_img_array, axis=1)
    result_img = cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR)
    result_img = np.power(result_img, 1/2.2)
    result_img = np.uint8(np.clip(result_img * 255., 0., 255.))
    cv2.imwrite(iter_path+"/target.png", result_img)
    print(num_sensors)
    print(sensor_sum)

    class PSDRRender(torch.autograd.Function):
        """Autograd bridge that evaluates the rendering loss with psdr_cuda."""

        @staticmethod
        def forward(ctx, V, R, D, S, idx):
            """Render sensor ``idx[1]`` of scene ``idx[0]`` and return its loss.

            ``V``, ``R``, ``D`` and ``S`` are the torch tensors holding the
            vertex positions and the roughness, diffuse and specular data.
            """
            scene_id, sensor_id = idx
            stage = OPT_Para['Stages'][current_stage]

            _vertex_pos = Vector3fD(V)
            _roughness  = FloatD(R)
            _diffuse    = Vector3fD(D)
            _specular   = Vector3fD(S)

            ek.set_requires_gradient(_vertex_pos, V.requires_grad)
            ek.set_requires_gradient(_roughness,  R.requires_grad)
            ek.set_requires_gradient(_diffuse,    D.requires_grad)
            ek.set_requires_gradient(_specular,   S.requires_grad)

            # Kept on ctx so backward() can read the enoki gradients.
            ctx.input1 = _vertex_pos
            ctx.input2 = _roughness
            ctx.input3 = _diffuse
            ctx.input4 = _specular

            # Clamp the textures: colours to [0, 1], roughness to the stage
            # interval [wt_leftR, wt_rightR].
            diffuse   = ek.select(_diffuse   > 1,   1.0, ek.select(_diffuse   < 0,    0.0,  _diffuse  ))
            specular  = ek.select(_specular  > 1,   1.0, ek.select(_specular  < 0,    0.0,  _specular ))
            roughness = ek.select(_roughness > stage['wt_rightR'], stage['wt_rightR'],
                                  ek.select(_roughness < stage['wt_leftR'], stage['wt_leftR'], _roughness))

            for scc_id in range(0, len(num_sensors)):
                sc[scc_id].param_map['Mesh[0]'].vertex_positions = _vertex_pos
                sc[scc_id].param_map['BSDF[0]'].diffuseReflectance.data  = diffuse
                sc[scc_id].param_map['BSDF[0]'].specularReflectance.data = specular
                sc[scc_id].param_map['BSDF[0]'].roughness.data           = roughness

            npixels = ro1.height * ro1.width

            if iter_num % OPT_Para['print_size'] == 0:
                # Roughness has a single channel, so no channel swap is applied
                # (cv2.cvtColor RGB2BGR is only defined for 3/4-channel input).
                roughness_image = roughness.numpy().reshape(R_res, R_res, 1)
                cv2.imwrite(iter_path+"roughness_iter_"+str(iter_num)+".exr", roughness_image)

                diffuse_image = diffuse.numpy().reshape(D_res, D_res, 3)
                diffuse_image = cv2.cvtColor(diffuse_image, cv2.COLOR_RGB2BGR)
                cv2.imwrite(iter_path+"diffuse_iter_"+str(iter_num)+".exr", diffuse_image)

                specular_image = specular.numpy().reshape(S_res, S_res, 3)
                specular_image = cv2.cvtColor(specular_image, cv2.COLOR_RGB2BGR)
                cv2.imwrite(iter_path+"specular_iter_"+str(iter_num)+".exr", specular_image)
                del roughness_image, specular_image, diffuse_image

            sc[scene_id].configure2([sensor_id])

            global_sensor = sensor_sum[scene_id] + sensor_id

            our_img = direct_integrator.renderD(sc[scene_id], sensor_id)
            tar_img = Vector3fD(tar_img_store[global_sensor].cuda())

            oursli_img = silhouette_integrator.renderD(sc[scene_id], sensor_id)
            tarsli_img = FloatD(tar_sli_store[global_sensor].cuda())

            render_loss = compute_render_loss(our_img, tar_img, stage['wt_image'], npixels) / OPT_Para['batch_size']
            sli_loss = compute_silhouette_loss(oursli_img, tarsli_img, stage['wt_sli'], npixels) / OPT_Para['batch_size']

            loss = render_loss + sli_loss
            ctx.out = loss
            out_torch = ctx.out.torch()
            return out_torch

        @staticmethod
        def backward(ctx, grad_out):
            """Back-propagate through enoki and return the input gradients."""
            ek.set_gradient(ctx.out, FloatC(grad_out))
            FloatD.backward()

            # NaN entries of the vertex gradient are replaced by zero.
            gV = ek.gradient(ctx.input1)
            nan_mask = ek.isnan(gV)
            gV = ek.select(nan_mask, 0, gV)
            gradV = gV.torch()

            gradR = ek.gradient(ctx.input2).torch()
            gradD = ek.gradient(ctx.input3).torch()
            gradS = ek.gradient(ctx.input4).torch()

            # ``idx`` is not differentiable, hence the trailing None.
            result = (gradV, gradR, gradD, gradS, None)
            del ctx.out, ctx.input1, ctx.input2, ctx.input3, ctx.input4
            return result

    E = psdr_cuda.Mesh.edge_indices(sc[0].param_map[mesh_key]).torch().long()

    # Optimized quantities, initialised from the first scene.
    V = Variable(sc[0].param_map[mesh_key].vertex_positions.torch(),          requires_grad=True)
    R = Variable(sc[0].param_map['BSDF[0]'].roughness.data.torch(),           requires_grad=True)
    D = Variable(sc[0].param_map['BSDF[0]'].diffuseReflectance.data.torch(),  requires_grad=True)
    S = Variable(sc[0].param_map['BSDF[0]'].specularReflectance.data.torch(), requires_grad=True)

    render = PSDRRender.apply

    print('Starting optimization')
    for stage_id, stage in enumerate(OPT_Para['Stages']):
        # Select the stage first so the npass check inspects the stage that is
        # about to run rather than the previous one.
        current_stage = stage_id
        if stage['npass'] <= 0:
            break

        # Build (I + lmbda * L) both as torch sparse tensor and as a factorized
        # cupy system for the mesh optimizer.
        L = uniform_laplacian(V, E).detach() * stage['lmbda']
        I = sparse_eye(L.shape[0])
        IL_term = I + L
        IL_coalesced = IL_term.coalesce()
        Lv = cp.asarray(IL_coalesced.values())
        Li = cp.asarray(IL_coalesced.indices())

        IL_term_s = cupyx.scipy.sparse.coo_matrix((Lv, Li), shape=L.shape)
        IL_term_s_solver = linalg.factorized(IL_term_s)

        params_mesh = [
            {'params': V, 'lr': stage['lr']},
        ]

        params_bsdf = [
            {'params': R, 'lr': stage['lr_roug']},
            {'params': D, 'lr': stage['lr_diff']},
            {'params': S, 'lr': stage['lr_spec']},
        ]
        optimizer_mesh = LGDescent(params_mesh, IL_term, IL_term_s_solver)
        optimizer_bsdf = torch.optim.Adam(params_bsdf)

        npass = stage['npass']

        for _ in range(npass+1):
            now = datetime.datetime.now()
            optimizer_mesh.zero_grad()
            optimizer_bsdf.zero_grad()
            total_loss = 0.0
            total_img_loss = 0.0
            total_tex_loss = 0.0
            total_corr_loss = 0.0
            total_tv_loss = 0.0

            for loop in range(0, sensor_count):
                print("Rendering camera:", loop, end = "\r")
                active_scene = randrange(0, len(num_sensors))
                active_sensor = randrange(0, num_sensors[active_scene])

                img_loss = render(V, R, D, S, [active_scene, active_sensor])
                tex_loss = texture_range_loss(D, S, R, weight=stage['wt_texture_rang'], left=stage['wt_leftR'], right=stage['wt_rightR'])
                corr_loss = texture_correlation_loss(D, S, R, res=D_res, weightS=stage['wt_texture_corrS'], weightR=stage['wt_texture_corrR'])
                tv_loss = total_variation_lossD(D, res=D_res, weightR=stage['wt_total_varD']) + total_variation_loss(R, res=R_res, weightR=stage['wt_total_varR'])

                loss = img_loss + tex_loss + corr_loss + tv_loss
                loss.backward()

                total_loss += loss.item() / float(sensor_count)
                total_img_loss += img_loss.item() / float(sensor_count)
                total_tex_loss += tex_loss.item() / float(sensor_count)
                total_corr_loss += corr_loss.item() / float(sensor_count)
                total_tv_loss += tv_loss.item() / float(sensor_count)
                del loss

            # backward() already accumulates into .grad across the sampled
            # cameras, so one division yields the batch mean. Summing .grad
            # again per camera would weight earlier cameras more heavily.
            for leaf in (V, R, D, S):
                if leaf.grad is not None:
                    leaf.grad.div_(float(sensor_count))

            optimizer_mesh.step()
            optimizer_bsdf.step()

            if iter_num % OPT_Para['print_size'] == 0:
                sc[0].param_map[mesh_key].dump(iter_path+"iter_" + str(iter_num) +".obj")
                temp_img_array = []
                for test_scene_id, test_sensor_id in test_array:
                    print("Rendering debug camera:", test_scene_id, test_sensor_id, end = "\r")
                    # Average debug_npass non-differentiable renderings.
                    for ii in range(0, debug_npass):
                        if ii == 0:
                            debug_img = direct_integrator.renderC(sc[test_scene_id], test_sensor_id) / float(debug_npass)
                        else:
                            debug_img += direct_integrator.renderC(sc[test_scene_id], test_sensor_id) / float(debug_npass)
                    temp_img = debug_img.numpy().reshape(ro1.height, ro1.width, 3)
                    temp_img_array.append(temp_img)

                # Targets on the top row, current renderings on the bottom row.
                debug_result_img = np.concatenate(temp_img_array, axis=1)
                debug_result_img = cv2.cvtColor(debug_result_img, cv2.COLOR_RGB2BGR)
                debug_result_img = np.power(debug_result_img, 1/2.2)
                debug_result_img = np.uint8(np.clip(debug_result_img * 255., 0., 255.))
                iter_result_img = np.concatenate([result_img, debug_result_img], axis=0)
                cv2.imwrite(iter_path+"img_out_"+str(iter_num)+".png", iter_result_img)
            print()

            elapsed = datetime.datetime.now() - now
            print('Stage:', current_stage,
                  'Iteration:', iter_num,
                  'Loss:', total_loss,
                  'Img_loss:', total_img_loss,
                  'Tex_loss:', total_tex_loss,
                  'Corr_loss:', total_corr_loss,
                  'Tv_loss:', total_tv_loss,
                  'Total time:', elapsed.total_seconds())
            iter_num += 1


if __name__ == "__main__":
    # Command-line entry point. The names assigned here (output_dir,
    # target_path, sli_path, scene_init) are module globals that
    # psdr_optimize reads directly.
    parser = argparse.ArgumentParser(
        prog='ESSVBRDF',
        description='Envmap shape SVBRDF optimization',
        epilog=''
    )
    parser.add_argument('--scene', required=True, type=str)
    parser.add_argument('--json', required=True, type=str)
    args = parser.parse_args()

    scene_name = args.scene
    output_dir = "./output/" + scene_name + "_" + args.json + "/"
    target_path = "./output/" + scene_name + "/target/"
    sli_path = "./output/" + scene_name + "/target_sli/"

    create_output_dir("./output/")
    create_output_dir(output_dir)
    source_path = "./data/mesh/"+args.scene+"/"

    # --json is given without extension; the context manager closes the file
    # as soon as the configuration has been parsed.
    with open(args.json+".json", encoding="utf-8") as config_fp:
        OPT_config = json.load(config_fp)

    # One scene description per index: opt_0.xml, opt_1.xml, ...
    scene_init = [
        source_path+"opt_"+str(i)+".xml"
        for i in range(0, OPT_config["scene_num"])
    ]
    psdr_optimize(OPT_config)
