import nvidia_smi
import numpy as np
import torch, math
import cv2

def _bytes_to_megabytes(bytes):
    """Convert a byte count to megabytes (1 MB = 1024 * 1024 bytes), rounded to 2 decimals."""
    kibibytes = bytes / 1024
    mebibytes = kibibytes / 1024
    return round(mebibytes, 2)

def print_gpu_usage(message=""):
    """Print the used and total memory of GPU 0 in MB, followed by a label.

    Args:
        message: Free-form text appended to the end of the printed line.
    """
    nvidia_smi.nvmlInit()
    try:
        # Only the first device (index 0) is queried.
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        print("[INFO] MEMORY USAGE:", _bytes_to_megabytes(info.used), "/", _bytes_to_megabytes(info.total), 'MB - ', message)
    finally:
        # Balance the nvmlInit() above. NVML reference-counts initialisation,
        # so this does not tear down a session some other caller still holds.
        nvidia_smi.nvmlShutdown()

def write_obj(verts, faces, fname='test.obj'):
    """Write a triangle mesh to a Wavefront OBJ file.

    Args:
        verts: Sequence of vertices; the first three components of each row
            are written as one ``v x y z`` line.
        faces: Sequence of triangles holding zero-based vertex indices. They
            are written one-based, as OBJ requires. The input is not modified.
        fname: Path of the file to create (overwritten if it exists).
    """
    vertices = np.asarray(verts)
    # OBJ indices start at 1; the addition yields a new array, so the
    # caller's `faces` is left untouched.
    one_based_faces = np.asarray(faces) + 1

    # The context manager guarantees the handle is closed even if a row is
    # malformed and formatting raises part-way through.
    with open(fname, 'w') as obj_file:
        for vertex in vertices:
            obj_file.write("v {0} {1} {2}\n".format(vertex[0], vertex[1], vertex[2]))
        for face in one_based_faces:
            obj_file.write("f {0} {1} {2}\n".format(face[0], face[1], face[2]))

def write_obj_uv(verts, faces, uv, faces_uv, fname='test.obj'):
    """Write a UV-mapped triangle mesh to a Wavefront OBJ file.

    Args:
        verts: Sequence of vertices; the first three components of each row
            are written as one ``v x y z`` line.
        faces: Sequence of triangles holding zero-based vertex indices.
        uv: Sequence of texture coordinates; the first two components of each
            row are written as one ``vt u v`` line.
        faces_uv: Sequence of triangles holding zero-based indices into `uv`,
            paired element-wise with `faces`.
        fname: Path of the file to create (overwritten if it exists).

    None of the inputs are modified. If `faces` and `faces_uv` differ in
    length, only the common prefix is written (``zip`` semantics).
    """
    vertices = np.asarray(verts)
    tex_coords = np.asarray(uv)
    # OBJ indices start at 1; the additions yield new arrays, so the caller's
    # index arrays are left untouched.
    one_based_faces = np.asarray(faces) + 1
    one_based_faces_uv = np.asarray(faces_uv) + 1

    # The context manager guarantees the handle is closed even if a row is
    # malformed and formatting raises part-way through.
    with open(fname, 'w') as obj_file:
        for vertex in vertices:
            obj_file.write("v {0} {1} {2}\n".format(vertex[0], vertex[1], vertex[2]))
        for coord in tex_coords:
            obj_file.write("vt {0} {1}\n".format(coord[0], coord[1]))
        for face_id, uv_id in zip(one_based_faces, one_based_faces_uv):
            # Texture references use a single slash ("v/vt"). The double-slash
            # form "v//vn" refers to normals, which would make OBJ readers
            # ignore the vt records written above.
            obj_file.write(
                "f {0}/{3} {1}/{4} {2}/{5}\n".format(
                    face_id[0], face_id[1], face_id[2],
                    uv_id[0], uv_id[1], uv_id[2]))

def make_grid(arr, ncols=2):
    """Tile a batch of equally sized images into one grid image.

    Args:
        arr: Array of shape (n, height, width, nc).
        ncols: Number of tiles per grid row; must divide n exactly.

    Returns:
        Array of shape (height * nrows, width * ncols, nc) in which image i
        occupies grid row i // ncols and grid column i % ncols.
    """
    count, tile_h, tile_w, channels = arr.shape
    nrows, leftover = divmod(count, ncols)
    assert leftover == 0
    # Split the batch axis into (row, col), then move the column axis next to
    # the width axis so the final reshape stitches tiles side by side.
    tiles = arr.reshape(nrows, ncols, tile_h, tile_w, channels)
    interleaved = tiles.swapaxes(1, 2)
    return interleaved.reshape(tile_h * nrows, tile_w * ncols, channels)


def fibonacci_sphere(samples, dist):
    """Generate points spread over a sphere using the golden-angle spiral.

    Args:
        samples: Number of points to generate. Zero or a negative value
            yields an empty list.
        dist: Sphere radius; every returned point lies at this distance from
            the origin.

    Returns:
        A list of ``[x, y, z]`` lists. The first point is the +y pole and,
        when ``samples > 1``, the last point is the -y pole.
    """
    points = []
    golden_angle = math.pi * (3. - math.sqrt(5.))  # in radians
    for i in range(samples):
        if samples > 1:
            y = 1 - (i / float(samples - 1)) * 2  # y goes from 1 to -1
        else:
            # A single sample has no interval to subdivide; pin it to the
            # +y pole rather than dividing by zero.
            y = 1.0
        radius = math.sqrt(1 - y * y)  # radius of the circle at height y
        theta = golden_angle * i  # golden angle increment
        x = math.cos(theta) * radius
        z = math.sin(theta) * radius
        points.append([dist * x, dist * y, dist * z])
    return points


def rgb2bgr(image):
    """Return `image` run through OpenCV's RGB-to-BGR colour conversion."""
    conversion_code = cv2.COLOR_RGB2BGR
    return cv2.cvtColor(image, conversion_code)

def bgr2rgb(image):
    """Return `image` run through OpenCV's BGR-to-RGB colour conversion."""
    conversion_code = cv2.COLOR_BGR2RGB
    return cv2.cvtColor(image, conversion_code)

def encode2SRGB(v):
    """Apply the sRGB transfer curve to a scalar and scale the result by 255.

    Values at or below 0.0031308 use the linear segment (slope 12.92);
    everything else uses the 1/2.4 power segment.
    """
    # Written as a negated "<=" so that NaN takes the power segment, exactly
    # as it would fall through to the else-branch of a plain "<=" test.
    if not (v <= 0.0031308):
        return (1.055 * (v ** (1.0 / 2.4)) - 0.055) * 255.0
    return (v * 12.92) * 255.0

def write_exr(image_name, image):
    """Save `image` through OpenCV after passing it through rgb2bgr.

    '.exr' is appended to `image_name` unless the name already contains that
    substring. This is a substring test, not a suffix test, so a name with
    '.exr' anywhere in it is used unchanged.
    """
    already_tagged = '.exr' in image_name
    target_path = image_name if already_tagged else image_name + '.exr'
    bgr_image = rgb2bgr(image)
    cv2.imwrite(target_path, bgr_image)

def read_exr(image_name):
    """Load an image through OpenCV and return it passed through bgr2rgb.

    '.exr' is appended to `image_name` unless the name already contains that
    substring. This is a substring test, not a suffix test, so a name with
    '.exr' anywhere in it is used unchanged.
    """
    already_tagged = '.exr' in image_name
    source_path = image_name if already_tagged else image_name + '.exr'
    # Keep whatever channel count and bit depth the file stores.
    read_flags = cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH
    bgr_image = cv2.imread(source_path, read_flags)
    return bgr2rgb(bgr_image)

def write_png(image_name, image, mapping=True):
    """Quantize an image to 8 bits per channel and save it through OpenCV.

    Args:
        image_name: Output path. '.png' is appended unless the name already
            contains that substring (a substring test, not a suffix test).
        image: Array-like of intensities. Values are scaled by 255 and clipped
            to 0..255, so inputs are presumably expected in [0, 1]. The
            result is passed through rgb2bgr before writing. The caller's
            array is never modified.
        mapping: When true, apply the sRGB transfer curve (the same formula
            as encode2SRGB) before scaling; when false, scale linearly.
    """
    # np.array makes a float64 copy, so no later step can write back into the
    # caller's buffer (the earlier in-place "*= 255" did when mapping was off).
    pixels = np.array(image, dtype=np.float64)

    if mapping:
        # Whole-array form of encode2SRGB, avoiding one Python call per pixel.
        linear_segment = pixels * 12.92
        # Negative inputs always select the linear segment below; clamping
        # them here only stops the fractional power from emitting NaN warnings.
        power_segment = 1.055 * np.power(np.maximum(pixels, 0.0), 1.0 / 2.4) - 0.055
        pixels = np.where(pixels <= 0.0031308, linear_segment, power_segment)

    pixels = pixels * 255.0
    quantized = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    if '.png' not in image_name:
        image_name += '.png'
    cv2.imwrite(image_name, rgb2bgr(quantized))
