import numpy as np
import torch,os,sys
import torch.nn as nn
import skimage.measure
from utils import print_gpu_usage

class BoundingBox:
    """Axis-aligned 3-D box sampled by a regular grid of ``res`` points per axis.

    Attributes:
        pMin, pMax: lower / upper corner of the box (3 entries each).
        res: number of grid samples per axis; both corners are samples
            (the grid is built with ``torch.linspace(pMin[i], pMax[i], res[i])``).
        voxel_size: distance between neighbouring samples along each axis,
            ``(pMax - pMin) / (res - 1)``. This is the ``spacing`` that marching
            cubes needs so that sample ``res - 1`` lands exactly on ``pMax``.
    """

    def __init__(self, pMin, pMax, res) -> None:
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not (len(pMin) == 3 and len(pMax) == 3 and len(res) == 3):
            raise ValueError("pMin, pMax and res must each have exactly 3 entries")
        if any(int(r) < 2 for r in res):
            raise ValueError("res must contain at least 2 samples per axis")
        self.pMin = pMin
        self.pMax = pMax
        self.res = res
        span = pMax - pMin
        # res may be a list or a tensor; bring it onto the span's device so the
        # division works either way. res samples span res - 1 intervals.
        counts = torch.as_tensor(res, device=span.device)
        self.voxel_size = span / (counts - 1)

def meshgrid(pMin: torch.Tensor,
             pMax: torch.Tensor,
             res: torch.Tensor):
    """Return every node of a regular 3-D grid as a ``(res[0]*res[1]*res[2], 3)`` tensor.

    Axis ``i`` is sampled with ``torch.linspace(pMin[i], pMax[i], res[i])``, so both
    bounds are included. Rows follow ``torch.meshgrid``'s default ``"ij"`` indexing
    (last axis varies fastest). Callers rely on this ordering to ``reshape`` a flat
    vector of per-point values back to ``tuple(res)``.
    """
    axes = [
        torch.linspace(lo, hi, steps, device=pMin.device)
        for lo, hi, steps in zip(pMin[:3], pMax[:3], res[:3])
    ]
    grids = torch.meshgrid(*axes)
    return torch.stack([g.reshape(-1) for g in grids], dim=1)


def convert_sdf_samples_to_mesh(sdf_tensor,
                                voxel_grid_origin,
                                voxel_size,
                                offset=None,
                                scale=None):
    """Run marching cubes on a dense SDF grid and return its zero level set.

    Args:
        sdf_tensor: 3-D tensor of SDF samples indexed ``[i, j, k]``.
        voxel_grid_origin: position of sample ``[0, 0, 0]``. It is moved to the
            device/dtype of the output vertices before being added.
        voxel_size: per-axis distance between neighbouring samples, used as the
            marching-cubes ``spacing``.
        offset: optional translation; if given, it is subtracted from the points
            after ``scale`` has been applied.
        scale: optional divisor; if given, the points are divided by it.

    Returns:
        ``(mesh_points, faces)``: float32 vertex positions and int32 triangle
        indices, both on ``sdf_tensor.device``.

    Raises:
        ValueError: raised by ``skimage.measure.marching_cubes`` when level 0 lies
            outside the range of the SDF samples (no surface to extract).
    """
    grid = sdf_tensor.detach().cpu().numpy()
    spacing = torch.as_tensor(voxel_size).detach().cpu().numpy()
    verts, faces, _normals, _values = skimage.measure.marching_cubes(
        grid, level=0.0, spacing=tuple(float(s) for s in spacing))

    device = sdf_tensor.device
    verts = torch.tensor(verts, dtype=torch.float, device=device)
    faces = torch.tensor(faces, dtype=torch.int, device=device)

    # Match device and dtype: adding a CPU (3,) origin to CUDA vertices would
    # raise, and a float64 origin would silently promote the vertices to float64.
    origin = torch.as_tensor(voxel_grid_origin).to(verts)
    mesh_points = origin + verts

    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset
    return mesh_points, faces


class IsoSurfaceFunction(torch.autograd.Function):
    """Route gradients on extracted mesh vertices back into the SDF values.

    Forward is the identity on ``verts``. Backward recomputes the unit SDF
    gradient ``n`` at each vertex with ``decoder``. It then turns the incoming
    vertex gradient ``g`` into ``-<g, n>`` for that vertex's SDF value, i.e. it
    treats the vertex as moving by ``-n`` per unit increase of the SDF.
    """

    @staticmethod
    def forward(ctx, sdf_value, verts, decoder):
        """Return a copy of ``verts``; ``sdf_value`` is the autograd input.

        Args:
            sdf_value: SDF predicted at ``verts``, one value per vertex. Any shape
                with ``len(verts)`` elements works, e.g. ``(V,)`` or ``(V, 1)``.
            verts: ``(V, 3)`` vertex positions.
            decoder: callable SDF, re-evaluated in backward to obtain normals.
        """
        verts = verts.clone().detach().requires_grad_(True)
        ctx.save_for_backward(verts)
        ctx.decoder = decoder
        # Remember the input shape so the gradient can be returned in it.
        ctx.sdf_shape = sdf_value.shape
        return verts

    @staticmethod
    def backward(ctx, grad_output):
        verts, = ctx.saved_tensors
        with torch.enable_grad():
            # Fresh leaf so the normal computation is independent of this node's graph.
            query = verts.detach().requires_grad_(True)
            pred_sdf = ctx.decoder(query)
            normals, = torch.autograd.grad(pred_sdf.sum(), query)
        # eps-clamped normalisation: a zero gradient yields a zero normal, not NaN.
        normals = torch.nn.functional.normalize(normals, p=2, dim=1)
        grad_sdf = -(grad_output * normals).sum(dim=1)
        # Reshape instead of squeeze(): works for (V, 1) inputs and for V == 1.
        return grad_sdf.reshape(ctx.sdf_shape), None, None


class IsoSurface(nn.Module):
    """Differentiable extraction of the zero level set of an SDF as a triangle mesh.

    The SDF is sampled on the regular grid described by ``bbox``. Marching cubes
    turns the samples into vertices and faces, and ``IsoSurfaceFunction`` makes the
    returned vertex positions differentiable with respect to the SDF output.

    Args:
        sdf: callable mapping an ``(N, 3)`` point tensor to the SDF at those points
            (assumed one value per point).
        bbox: grid bounds and resolution.
        scaling: per-axis divisor applied to grid points before querying ``sdf``.
            ``None`` (default) means no scaling.
        save_memory: if True, the grid is stored and evaluated in chunks of
            ``chunk_size`` points.
        device: device the grid samples live on (default ``"cuda"``).
        chunk_size: points per chunk when ``save_memory`` is True.
    """

    def __init__(self, sdf, bbox: BoundingBox, scaling=None, save_memory=True,
                 device="cuda", chunk_size=8192):
        super().__init__()
        self.sdf = sdf
        self.bbox = bbox
        self._device = device
        # None instead of a tensor default: a default tensor object is created
        # once and shared by every instance.
        if scaling is None:
            scaling = torch.ones(3)
        scaling = torch.as_tensor(scaling, device=device)
        # NOTE(review): the grid is queried at ``points / scaling``, but the mesh
        # vertices returned by forward() are in unscaled bbox coordinates and are
        # fed to ``self.sdf`` unscaled. This is consistent only when scaling is all
        # ones; confirm the intended behaviour.
        self.samples = meshgrid(bbox.pMin, bbox.pMax, bbox.res).to(device) / scaling
        self.num_samples = self.samples.shape[0]
        if save_memory:
            self.samples = torch.split(self.samples, chunk_size)

    def forward(self):
        """Extract the current zero level set.

        Returns:
            ``(verts, faces)``: vertex positions whose gradient flows back into
            ``self.sdf`` through ``IsoSurfaceFunction``, and the triangle indices
            produced by ``convert_sdf_samples_to_mesh``.
        """
        # Grid values only feed marching cubes (not differentiable), so neither
        # evaluation path needs an autograd graph over the full grid.
        with torch.no_grad():
            if isinstance(self.samples, tuple):
                sdf_values = torch.empty(self.num_samples, device=self._device)
                start = 0
                for chunk in self.samples:
                    stop = start + chunk.shape[0]
                    # reshape(-1) accepts both (n,) and (n, 1) SDF outputs.
                    sdf_values[start:stop] = self.sdf(chunk).reshape(-1)
                    start = stop
            else:
                sdf_values = self.sdf(self.samples)
        grid_shape = tuple(int(r) for r in self.bbox.res)
        sdf_values = sdf_values.reshape(grid_shape)
        verts, faces = convert_sdf_samples_to_mesh(sdf_values, self.bbox.pMin, self.bbox.voxel_size)
        # Re-evaluate at the vertices with autograd on so the SDF network's
        # parameters receive gradients through IsoSurfaceFunction.
        pred_sdf = self.sdf(verts)
        verts = IsoSurfaceFunction.apply(pred_sdf, verts, self.sdf)
        return verts, faces