import torch 
import torch.nn as nn 
import numpy as np
import cv2
# import enoki as ek


def compute_render_loss(our_img, tar_img, weight, npixels):
    """Weighted mean absolute error between two images clamped to at most 1.

    Values above 1 are clamped to 1 in both images before they are compared.
    The clamp is done out of place, so the caller's tensors are never
    modified. (Masked in-place assignment would overwrite the inputs and
    raises on leaf tensors that require grad.)

    Args:
        our_img: Tensor holding our image (presumably the rendering).
        tar_img: Tensor holding the target image; must broadcast against
            ``our_img``.
        weight: Scalar multiplier applied to the loss.
        npixels: Unused. Kept for call compatibility; normalisation is done
            by ``mean`` over all elements.

    Returns:
        Scalar tensor ``mean(|min(our_img, 1) - min(tar_img, 1)|) * weight``.
    """
    clamped_ours = torch.clamp(our_img, max=1.0)
    clamped_target = torch.clamp(tar_img, max=1.0)
    return (clamped_ours - clamped_target).abs().mean() * weight

def compute_envmap_loss(our_env, tar_env, weight, npixels):
    """Weighted mean squared error between two environment-map tensors.

    Args:
        our_env: Tensor with our current estimate.
        tar_env: Tensor with the target values; must broadcast against
            ``our_env``.
        weight: Scalar multiplier applied to the loss.
        npixels: Unused; the mean over all elements does the normalisation.

    Returns:
        Scalar tensor ``mean((our_env - tar_env) ** 2) * weight``.
    """
    squared_error = (our_env - tar_env).pow(2)
    return squared_error.mean() * weight

def compute_render_loss_mask(our_img, tar_img, mask_img, weight, npixels):
    """Masked, weighted mean absolute error between two clamped images.

    Values above 1 are clamped to 1 in both images (out of place, so the
    caller's tensors are left untouched). The absolute error is then zeroed
    wherever ``mask_img < 254`` and averaged over *all* elements, masked ones
    included, so masked-out entries still count in the denominator.

    Args:
        our_img: Tensor holding our image (presumably the rendering).
        tar_img: Tensor holding the target image; must broadcast against
            ``our_img``.
        mask_img: Tensor compared against 254. It is used as a boolean index
            into the error tensor, so its shape must match the error tensor's
            shape or its leading dimensions.
        weight: Scalar multiplier applied to the loss.
        npixels: Unused. Kept for call compatibility; normalisation is done
            by ``mean`` over all elements.

    Returns:
        Scalar tensor with the masked mean absolute error times ``weight``.
    """
    clamped_ours = torch.clamp(our_img, max=1.0)
    clamped_target = torch.clamp(tar_img, max=1.0)
    # The difference is a fresh tensor, so editing it in place is safe.
    diff = (clamped_ours - clamped_target).abs()
    diff[mask_img < 254.] = 0.
    return diff.mean() * weight

def compute_silhouette_loss(our_sil, tar_sil, weight, npixels):
    """Weighted L1 distance between a 3-channel silhouette and a target.

    Implemented with torch like the other losses in this module. The earlier
    version called the ``ek`` (enoki) API, whose import is commented out at
    the top of the file, so it raised ``NameError`` on every call.

    Args:
        our_sil: Tensor indexable by channel on its first axis; channels
            0, 1 and 2 are each compared against ``tar_sil``.
        tar_sil: Target silhouette tensor; must broadcast against one
            channel of ``our_sil``.
        weight: Scalar multiplier applied to the loss.
        npixels: Pixel count used for normalisation; the summed error is
            divided by ``npixels * 3``.

    Returns:
        Scalar tensor ``sum_c |our_sil[c] - tar_sil| / (npixels * 3) * weight``.
    """
    diff = (
        torch.abs(our_sil[0] - tar_sil)
        + torch.abs(our_sil[1] - tar_sil)
        + torch.abs(our_sil[2] - tar_sil)
    )
    loss = diff.sum() / (npixels * 3.0)
    return loss * weight


def uniform_laplacian(verts, edges):
    """Build the sparse uniform (combinatorial) Laplacian ``D - A`` of a mesh.

    ``A`` is the symmetric adjacency matrix, with ``A[i, j]`` equal to the
    number of times the edge ``(i, j)`` is listed, and ``D`` is the diagonal
    matrix of its column sums (vertex degrees). The matrix depends only on
    connectivity, so it needs computing once per mesh subdivision; it is
    built under ``torch.no_grad()``.

    Every intermediate tensor is created on ``verts.device``, so the function
    also works for meshes that live on an accelerator.

    Args:
        verts: Tensor whose first dimension is the vertex count ``V``; only
            its shape and device are read.
        edges: Integer tensor of shape ``(E, >=2)``; columns 0 and 1 hold the
            two vertex indices of each edge.

    Returns:
        Sparse COO ``float32`` tensor of shape ``(V, V)``.
    """
    with torch.no_grad():
        device = verts.device
        V = verts.shape[0]
        # Sparse indices must be int64 and share a device with the values.
        e0 = edges[:, 0].to(device=device, dtype=torch.long)
        e1 = edges[:, 1].to(device=device, dtype=torch.long)

        # Symmetric adjacency: A[e0, e1] = 1 and A[e1, e0] = 1.
        rows = torch.cat([e0, e1], dim=0)
        cols = torch.cat([e1, e0], dim=0)
        adj_idx = torch.stack([rows, cols], dim=0)  # (2, 2*E)
        ones = torch.ones(adj_idx.shape[1], dtype=torch.float32, device=device)
        adjacency = torch.sparse_coo_tensor(adj_idx, ones, (V, V))

        # Column sums of A, i.e. the degree of every vertex.
        degree = torch.zeros(V, dtype=torch.float32, device=device)
        degree.index_add_(0, cols, ones)

        diag = torch.arange(V, device=device)
        diag_idx = torch.stack([diag, diag], dim=0)
        L = torch.sparse_coo_tensor(diag_idx, degree, (V, V)) - adjacency
    return L
    
def texture_range_loss(D, S, R, weight, left, right):
    """Penalise texture values that fall outside their allowed interval.

    ``D`` and ``S`` are penalised for leaving ``[0, 1]`` and ``R`` for leaving
    ``[left, right]``. The penalty for each element is the squared distance
    to the nearest interval bound, and zero inside the interval.

    Args:
        D: Tensor constrained to ``[0, 1]``.
        S: Tensor constrained to ``[0, 1]``.
        R: Tensor constrained to ``[left, right]``.
        weight: Scalar multiplier applied to the summed loss.
        left: Lower bound for ``R``.
        right: Upper bound for ``R``.

    Returns:
        Scalar tensor: the sum of the three mean penalties times ``weight``.
    """
    def _squared_overshoot(tex, lo, hi):
        # Distance past the upper bound, or zero when not above it.
        above = torch.where(tex > hi, tex - hi, torch.zeros_like(tex))
        # Distance below the lower bound takes precedence when it applies.
        return torch.where(tex < lo, lo - tex, above).pow(2)

    penalty_s = _squared_overshoot(S, 0, 1)
    penalty_d = _squared_overshoot(D, 0, 1)
    penalty_r = _squared_overshoot(R, left, right)
    return (penalty_s.mean() + penalty_d.mean() + penalty_r.mean()) * weight

def texture_correlation_loss(D, S, R, res, weightS, weightR):
    """Cosine-similarity loss tying ``S`` and ``R`` to a detached copy of ``D``.

    ``D`` is detached and used only as a fixed reference, so no gradient
    flows back into it. Cosine similarity is taken along dimension 0:

    * ``S`` is compared against ``D``;
    * ``R`` is compared against the mean of ``D`` over dimension 1.

    NOTE(review): the R term used to be computed from ``S``, which left the
    ``R`` argument unused; the ``lossR`` / ``weightR`` naming indicates ``R``
    was intended. Confirm against the training setup.

    The reference is moved to the device of ``S`` / ``R`` rather than
    unconditionally to CUDA, so the loss also runs on CPU.

    Args:
        D: Reference tensor with at least two dimensions; presumably
            ``(N, 3)`` — TODO confirm.
        S: Tensor compared against ``D``; must broadcast against it.
        R: Tensor compared against the dimension-1 mean of ``D``; presumably
            ``(N, 1)`` — TODO confirm.
        res: Unused; kept for call compatibility.
        weightS: Scalar weight of the S term.
        weightR: Scalar weight of the R term.

    Returns:
        Scalar tensor ``mean(1 - cos(S, D)) * weightS
        + mean(1 - cos(R, mean(D))) * weightR``.
    """
    # Cosine similarity along dimension 0.
    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    D_ref = D.detach().float()
    D_mean = D_ref.mean(dim=1, keepdim=True)
    lossS = 1 - cos(S, D_ref.to(S.device))
    lossR = 1 - cos(R, D_mean.to(R.device))
    loss = lossS.mean() * weightS + lossR.mean() * weightR
    return loss
    

def total_variation_loss(_R, res, weightR):
    """Weighted squared total variation of a ``res`` x ``res`` texture.

    Args:
        _R: Tensor with ``res * res`` elements; it is reshaped to
            ``(res, res)``.
        res: Side length of the square grid.
        weightR: Scalar multiplier applied to the loss.

    Returns:
        Scalar tensor: the mean squared difference between vertically
        adjacent entries plus the mean squared difference between
        horizontally adjacent entries, times ``weightR``.
    """
    grid = _R.reshape(res, res)
    row_step = grid[:-1] - grid[1:]        # neighbours along dimension 0
    col_step = grid[:, :-1] - grid[:, 1:]  # neighbours along dimension 1
    variation = row_step.pow(2).mean() + col_step.pow(2).mean()
    return variation * weightR
