import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import igl

def image_loss(npass, data, output_dir):
    """Plot the per-iteration image loss curve and save it as ``loss.png``.

    Args:
        npass: Number of iterations. The x axis is ``np.arange(0, npass)``,
            so ``data`` must contain exactly ``npass`` values.
        data: Loss value for each iteration.
        output_dir: Prefix prepended verbatim to ``"loss.png"``. Include a
            trailing path separator if it names a directory.
    """
    x = np.arange(0, npass)
    fig, ax = plt.subplots()
    try:
        ax.set_title("Image loss")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Loss")
        ax.plot(x, data, color="red")
        fig.savefig(output_dir + "loss.png", dpi=100, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving raises, so repeated
        # calls (e.g. once per optimisation pass) don't accumulate open figures.
        plt.close(fig)

def para_loss(npass, data, output_dir):
    """Plot the per-iteration parameter difference and save it as ``para.png``.

    Args:
        npass: Number of iterations. The x axis is ``np.arange(0, npass)``,
            so ``data`` must contain exactly ``npass`` values.
        data: Parameter-difference value for each iteration.
        output_dir: Prefix prepended verbatim to ``"para.png"``. Include a
            trailing path separator if it names a directory.
    """
    x = np.arange(0, npass)
    fig, ax = plt.subplots()
    try:
        ax.set_title("Parameter difference")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Para diff")
        ax.plot(x, data, color="red")
        fig.savefig(output_dir + "para.png", dpi=100, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving raises, so repeated
        # calls don't accumulate open pyplot figures.
        plt.close(fig)

def write_obj(verts, faces, fname='test.obj'):
    """Write a polygon mesh to a Wavefront OBJ file.

    Args:
        verts: Vertex positions. Each row supplies x, y, z at indices 0..2
            and is written as a ``v x y z`` line.
        faces: 0-based vertex indices, one row per face. Rows of length 3 or
            4 are written as ``f`` lines. Indices are shifted by +1 because
            OBJ indices are 1-based. The caller's array is not modified.
        fname: Output path. The file is created or overwritten.

    Raises:
        ValueError: If a face row has a length other than 3 or 4.
    """
    verts = np.asarray(verts)
    # ``+ 1`` produces a new array, so the caller's ``faces`` stays 0-based.
    faces = np.asarray(faces) + 1
    # The context manager guarantees the handle is closed even when a
    # ValueError is raised part-way through writing the faces.
    with open(fname, 'w') as file:
        for item in verts:
            file.write("v {0} {1} {2}\n".format(item[0], item[1], item[2]))
        for item in faces:
            if len(item) == 3:
                file.write("f {0} {1} {2}\n".format(item[0], item[1], item[2]))
            elif len(item) == 4:
                file.write("f {0} {1} {2} {3}\n".format(
                    item[0], item[1], item[2], item[3]))
            else:
                raise ValueError("Invalid face")


def write_obj_texture(V, VT, Vi, Vti, fname='test.obj'):
    """Write a mesh with texture coordinates to a Wavefront OBJ file.

    Args:
        V: Vertex positions. Each row supplies x, y, z at indices 0..2
            (written as ``v`` lines).
        VT: Texture coordinates. Each row supplies u, v at indices 0..1
            (written as ``vt`` lines).
        Vi: 0-based position indices, one row per face.
        Vti: 0-based texture-coordinate indices, one row per face, aligned
            row-for-row with ``Vi``. Each face is written as
            ``f p1/t1 p2/t2 ...`` with 1-based indices. The caller's index
            arrays are not modified.
        fname: Output path. The file is created or overwritten.

    Raises:
        ValueError: If ``Vi`` and ``Vti`` have different numbers of rows, if
            a face's position and texture rows differ in length, or if a
            face has a length other than 3 or 4. Validation happens before
            ``fname`` is opened, so invalid input does not clobber the file.
    """
    V = np.asarray(V)
    VT = np.asarray(VT)
    # OBJ indices are 1-based. ``+ 1`` returns new arrays, leaving the
    # caller's 0-based arrays untouched.
    Vi = np.asarray(Vi) + 1
    Vti = np.asarray(Vti) + 1

    # Explicit checks (rather than ``assert``) so validation survives ``python -O``.
    if Vi.shape[0] != Vti.shape[0]:
        raise ValueError(
            "Vi and Vti must have the same number of faces ({0} != {1})".format(
                Vi.shape[0], Vti.shape[0]))
    for vi, vti in zip(Vi, Vti):
        if len(vi) not in (3, 4):
            raise ValueError(
                "Invalid face: expected 3 or 4 corners, got {0}".format(len(vi)))
        if len(vti) != len(vi):
            raise ValueError(
                "Invalid face: position and texture index rows differ in length")

    with open(fname, 'w') as file:
        for item in V:
            file.write("v {0} {1} {2}\n".format(item[0], item[1], item[2]))
        for item in VT:
            file.write("vt {0} {1}\n".format(item[0], item[1]))
        for vi, vti in zip(Vi, Vti):
            corners = " ".join("{0}/{1}".format(p, t) for p, t in zip(vi, vti))
            file.write("f " + corners + "\n")


def get_mesh_error(Vi, Fi, Vt, Ft, vmax=0.03):
    """Return a symmetric, scale-normalised distance error between two meshes.

    The computation runs in four steps:

    1. Both vertex sets are divided by the largest axis-aligned extent of the
       reference vertices ``Vt``.
    2. In each direction, ``igl.point_mesh_squared_distance`` gives the
       squared distance from every vertex of one mesh to the other mesh
       (``Vi`` -> (``Vt``, ``Ft``) and ``Vt`` -> (``Vi``, ``Fi``)).
    3. Those distances are mapped through
       ``matplotlib.colors.Normalize(vmin=0.0, vmax=vmax)`` and averaged.
       The normaliser is built without ``clip=True``, so values above
       ``vmax`` map above 1.0 rather than saturating.
    4. The two directional means are averaged.

    Args:
        Vi, Fi: Vertices and faces of the mesh being evaluated.
        Vt, Ft: Vertices and faces of the reference mesh. Its extent defines
            the scale.
        vmax: Squared (normalised) distance that maps to 1.0. Defaults to 0.03.

    Returns:
        Scalar mean of the two normalised directional errors.

    Raises:
        ValueError: If ``Vt`` has zero extent, because scale normalisation
            would divide by zero.
    """
    # Float64, C-ordered copies: in-place ``/=`` below would fail on integer
    # inputs, and the caller's arrays must not be mutated.
    Vi = np.array(Vi, dtype=np.float64, order='C')
    Vt = np.array(Vt, dtype=np.float64, order='C')
    Fi = Fi.copy()
    Ft = Ft.copy()
    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=vmax)
    max_len = (Vt.max(axis=0) - Vt.min(axis=0)).max()
    if not max_len > 0:  # also rejects nan
        raise ValueError("Reference mesh Vt has zero extent; cannot normalise scale")
    Vi /= max_len
    Vt /= max_len
    dist1, _, _ = igl.point_mesh_squared_distance(Vi, Vt, Ft)
    dist2, _, _ = igl.point_mesh_squared_distance(Vt, Vi, Fi)
    return (norm(dist1).mean() + norm(dist2).mean()) / 2.0