import torch 
import torch.nn as nn 
import numpy as np
import cv2
# import enoki as ek


def compute_render_loss(our_img, tar_img, weight):
    """Return the weighted sum of absolute differences between two images.

    Args:
        our_img: Rendered image tensor.
        tar_img: Target image tensor, broadcast-compatible with ``our_img``.
        weight: Scalar multiplier applied to the summed error.

    Returns:
        ``weight * sum(|our_img - tar_img|)``. Entries where the difference is
        NaN contribute 0; note that ``nan_to_num`` also clamps +/-inf to the
        largest finite value of the dtype.
    """
    # Zero out NaNs so that undefined pixels do not poison the whole loss.
    residual = torch.nan_to_num(our_img - tar_img, nan=0)
    abs_err = torch.abs(residual)
    return abs_err.sum() * weight

def L2_loss(our_img, tar_img, weight):
    """Return the weighted sum of squared differences between two images.

    Args:
        our_img: Rendered image tensor.
        tar_img: Target image tensor, broadcast-compatible with ``our_img``.
        weight: Scalar multiplier applied to the summed error.

    Returns:
        ``weight * sum((our_img - tar_img) ** 2)``. Entries where the
        difference is NaN contribute 0; ``nan_to_num`` also clamps +/-inf to
        the largest finite value of the dtype.
    """
    # Zero out NaNs so that undefined pixels do not poison the whole loss.
    residual = torch.nan_to_num(our_img - tar_img, nan=0)
    sq_err = torch.square(residual)
    return sq_err.sum() * weight

def uniform_laplacian(verts, edges):
    """Build the uniform (combinatorial) graph Laplacian ``L = D - A``.

    ``A`` is the symmetric adjacency matrix: every row ``(i, j)`` of ``edges``
    contributes 1 to both ``A[i, j]`` and ``A[j, i]``. Duplicate edges
    therefore accumulate weight, and a self-loop ``(i, i)`` contributes 2 to
    ``A[i, i]``. ``D`` is the diagonal matrix of row sums of ``A``. Only the
    first two columns of ``edges`` are read.

    The matrix depends only on connectivity, so callers presumably compute it
    once per mesh subdivision and reuse it (NOTE(review): confirm at callers).

    Args:
        verts: Vertex tensor. Only ``verts.shape[0]`` (the vertex count V) and
            ``verts.device`` are used.
        edges: Integer tensor of vertex indices with at least two columns,
            each index in ``[0, V)``.

    Returns:
        A coalesced sparse COO ``float32`` tensor of shape ``(V, V)`` on
        ``verts.device``, built under ``torch.no_grad()``.
    """
    with torch.no_grad():
        num_verts = verts.shape[0]
        device = verts.device
        e0 = edges[:, 0]
        e1 = edges[:, 1]

        # Symmetric adjacency: store both (e0, e1) and (e1, e0).
        rows = torch.cat([e0, e1]).to(device)
        cols = torch.cat([e1, e0]).to(device)
        adj_idx = torch.stack([rows, cols], dim=0)
        ones = torch.ones(adj_idx.shape[1], dtype=torch.float32, device=device)
        adjacency = torch.sparse_coo_tensor(adj_idx, ones, (num_verts, num_verts))

        # Degree of each vertex = row sum of A. minlength keeps isolated
        # vertices (and the empty-edge case) at degree 0.
        degree = torch.bincount(rows, minlength=num_verts).to(torch.float32)
        diag = torch.arange(num_verts, device=device)
        degree_mat = torch.sparse_coo_tensor(
            torch.stack([diag, diag], dim=0), degree, (num_verts, num_verts)
        )

        L = (degree_mat - adjacency).coalesce()
    return L
