import torch 
import torch.nn as nn 
import numpy as np
import cv2
# import enoki as ek


def compute_render_loss(our_img, tar_img, weight):
    """Return the weighted L1 (sum of absolute differences) loss between two images.

    NaN entries of the difference are replaced by 0 via ``nan_to_num`` before the
    reduction. With the remaining ``nan_to_num`` defaults, +/-inf become the dtype's
    largest finite values.

    Args:
        our_img: rendered image tensor.
        tar_img: target image; subtracted from ``our_img`` with torch broadcasting.
        weight: multiplier applied to the summed loss.

    Returns:
        ``weight * sum(|our_img - tar_img|)`` as a tensor.
    """
    residual = torch.nan_to_num(our_img - tar_img, nan=0)
    total = torch.abs(residual).sum()
    return total * weight

def L2_loss(our_img, tar_img, weight):
    """Return the weighted sum-of-squared-differences loss between two images.

    NaN entries of the difference are replaced by 0 via ``nan_to_num`` before
    squaring. With the remaining ``nan_to_num`` defaults, +/-inf become the dtype's
    largest finite values.

    Args:
        our_img: rendered image tensor.
        tar_img: target image; subtracted from ``our_img`` with torch broadcasting.
        weight: multiplier applied to the summed loss.

    Returns:
        ``weight * sum((our_img - tar_img) ** 2)`` as a tensor (no mean, no sqrt).
    """
    residual = torch.nan_to_num(our_img - tar_img, nan=0)
    total = torch.square(residual).sum()
    return total * weight

def uniform_laplacian(verts, edges):
    """Build the uniform (combinatorial) graph Laplacian ``L = D - A`` as a sparse tensor.

    ``A[i, j]`` counts how many rows of ``edges`` connect ``i`` and ``j`` (in either
    order), and ``D`` is the diagonal matrix of vertex degrees. Every row of ``L``
    therefore sums to zero.

    The computation runs under ``torch.no_grad()``, so the result carries no autograd
    history. It is meant to be computed once per mesh subdivision.

    Args:
        verts: vertex tensor. Only ``verts.shape[0]`` (the vertex count) and
            ``verts.device`` are used.
        edges: integer tensor whose columns 0 and 1 hold the edge endpoints. Any
            extra columns are ignored.

    Returns:
        A coalesced ``(V, V)`` float32 sparse COO tensor on ``verts.device``.
    """
    # compute L once per mesh subdiv.
    with torch.no_grad():
        num_verts = verts.shape[0]
        device = verts.device
        e0 = edges[:, 0].to(device=device, dtype=torch.int64)
        e1 = edges[:, 1].to(device=device, dtype=torch.int64)

        # Off-diagonal entries: L[e0, e1] = L[e1, e0] = -1. Duplicate edges accumulate
        # when the sparse tensor is coalesced.
        offdiag_idx = torch.stack([torch.cat([e0, e1]), torch.cat([e1, e0])], dim=0)  # (2, 2E)
        offdiag_val = -torch.ones(offdiag_idx.shape[1], dtype=torch.float32, device=device)

        # Diagonal entries: L[i, i] = degree(i), i.e. how often i appears as an endpoint.
        # minlength keeps isolated vertices (degree 0) in the output.
        degree = torch.bincount(torch.cat([e0, e1]), minlength=num_verts).to(torch.float32)
        diag = torch.arange(num_verts, device=device)
        diag_idx = torch.stack([diag, diag], dim=0)

        L = torch.sparse_coo_tensor(
            torch.cat([diag_idx, offdiag_idx], dim=1),
            torch.cat([degree, offdiag_val]),
            (num_verts, num_verts),
        ).coalesce()
    return L

def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments=None) -> torch.Tensor:
    """Sum the rows of ``data`` that share a segment id, like ``tf.math.segment_sum``.

    ``result[k] = sum(data[i] for i where segment_ids[i] == k)``. The output has
    ``max(segment_ids) + 1`` rows unless ``num_segments`` is given. Ids that never
    occur produce zero rows, so gaps in the ids (e.g. ``[0, 0, 2]``) are handled.

    Args:
        data: tensor whose first dimension is the one being segmented.
        segment_ids: non-negative integer ids. Either 1-D with one id per row of
            ``data`` (broadcast across the trailing dims), or exactly ``data.shape``.
        num_segments: optional output row count. It must exceed every id.

    Returns:
        Tensor of shape ``(num_segments, *data.shape[1:])`` with ``data``'s dtype
        and device.

    Raises:
        AssertionError: if the (broadcast) ids do not match ``data.shape``.
    """
    segment_ids = segment_ids.to(device=data.device, dtype=torch.int64)

    if num_segments is None:
        # Size from the largest id rather than by counting distinct ids, so that
        # missing ids cannot push an index out of range in scatter_add.
        num_segments = int(segment_ids.max()) + 1 if segment_ids.numel() > 0 else 0

    # Repeat each id across the trailing dims of data.
    if segment_ids.dim() == 1:
        trailing = data.shape[1:]
        segment_ids = segment_ids.view(-1, *([1] * len(trailing))).expand(
            segment_ids.shape[0], *trailing
        )

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    # Allocate with data's dtype/device; scatter_add requires matching dtypes.
    result = data.new_zeros((num_segments, *data.shape[1:]))
    return result.scatter_add(0, segment_ids, data)

def laplace_regularizer_const(V, E, F):
    """Build a uniform-weight Laplacian smoothness regularizer for a triangle mesh.

    For each vertex i with directed face-neighbours N(i), the Laplacian vector is
    ``sum_{j in N(i)} (x_j - x_i) / |N(i)|``. ``eval`` returns the sum of squares of
    all these vectors. The neighbour structure is built once, from ``F``, at
    construction time.

    Args:
        V: vertex positions; ``V.shape[0]`` sets the vertex count.
        E: stored in ``inputs`` but otherwise unused.
        F: triangle vertex indices, one row of three indices per face. May be a list,
            NumPy array, or torch tensor.

    Returns:
        An object whose ``eval(V)`` returns the scalar regularization term. The term
        is differentiable w.r.t. ``V``.
    """
    class mesh_op_laplace_regularizer_const:
        def __init__(self, V, E, F):
            self.inputs = [V, E, F]

            self.nVerts = V.shape[0]
            # Convert arrays/tensors to plain Python ints once. Iterating a torch
            # tensor row by row yields 0-d tensors and is very slow.
            t_pos_idx = F.tolist() if hasattr(F, "tolist") else F

            # Build vertex neighbor rings (each face edge i0 -> i1, in face order).
            vtx_n = [[] for _ in range(self.nVerts)]
            for tri in t_pos_idx:
                for (i0, i1) in [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]:
                    vtx_n[i0].append(i1)

            # Collect index/weight pairs to compute each Laplacian vector for each vertex.
            # Similar notation to https://mgarland.org/class/geom04/material/smoothing.pdf
            ix_j, ix_i, w_ij = [], [], []
            for i, neighbors in enumerate(vtx_n):
                m = len(neighbors)
                if m == 0:
                    # Vertex not referenced by any face: it has no Laplacian term,
                    # and 1.0 / m would raise ZeroDivisionError.
                    continue
                ix_i += [i] * m
                ix_j += neighbors
                w_ij += [1.0 / m] * m

            # Setup torch tensors (moved to the vertex device lazily in eval).
            self.ix_i = torch.tensor(ix_i, dtype=torch.int64)
            self.ix_j = torch.tensor(ix_j, dtype=torch.int64)
            self.w_ij = torch.tensor(w_ij, dtype=torch.float32)[:, None]

        def eval(self, V):
            v_pos = V
            device = v_pos.device
            # .to() is a no-op when already on the right device.
            ix_i = self.ix_i.to(device)
            ix_j = self.ix_j.to(device)
            w_ij = self.w_ij.to(device)

            # Gather edge vertex pairs
            x_i = v_pos[ix_i, :]
            x_j = v_pos[ix_j, :]

            # Compute Laplacian differences: (x_j - x_i) * w_ij
            term = (x_j - x_i) * w_ij

            # Accumulate one Laplacian vector per vertex. Rows for vertices without
            # neighbours stay zero, so they do not affect the sum below.
            lap = term.new_zeros((v_pos.shape[0],) + tuple(term.shape[1:]))
            lap = lap.index_add(0, ix_i, term)

            return torch.sum(lap ** 2)

    return mesh_op_laplace_regularizer_const(V, E, F)
