import os

import matplotlib.pyplot as plt
import numpy as np


def _plot_curve(npass, data, output_dir, filename, title, ylabel):
    """Plot ``data`` against iteration index and save the figure to disk.

    Shared implementation behind ``image_loss`` and ``para_loss``.

    Args:
        npass: Number of iterations; the x-axis values are ``0 .. npass - 1``.
        data: Sequence of y-values; must contain ``npass`` elements so it
            lines up with the x-axis.
        output_dir: Destination for the image. A ``str`` is concatenated
            directly with ``filename`` (so a directory must end with a path
            separator, e.g. ``"out/"``). An ``os.PathLike`` such as
            ``pathlib.Path`` is joined with ``os.path.join``.
        filename: Name of the image file to write.
        title: Plot title.
        ylabel: Y-axis label.
    """
    if isinstance(output_dir, os.PathLike):
        path = os.path.join(output_dir, filename)
    else:
        # Plain concatenation is kept for str inputs so existing callers that
        # pass a trailing-slash directory (or a filename prefix) keep writing
        # to exactly the same path as before.
        path = output_dir + filename

    x = np.arange(0, npass)
    fig, ax = plt.subplots()
    try:
        ax.set_title(title)
        ax.set_xlabel("Iteration")
        ax.set_ylabel(ylabel)
        ax.plot(x, data, color="red")
        fig.savefig(path, dpi=100, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting/saving fails, so repeated
        # calls do not accumulate open figures in pyplot's global registry.
        plt.close(fig)


def image_loss(npass, data, output_dir):
    """Save a plot of the per-iteration image loss as ``loss.png``.

    Args:
        npass: Number of iterations plotted on the x-axis.
        data: ``npass`` loss values, one per iteration.
        output_dir: Directory (str ending in a separator, or PathLike)
            that receives ``loss.png``.
    """
    _plot_curve(npass, data, output_dir, "loss.png", "Image loss", "Loss")


def para_loss(npass, data, output_dir):
    """Save a plot of the per-iteration parameter difference as ``para.png``.

    Args:
        npass: Number of iterations plotted on the x-axis.
        data: ``npass`` parameter-difference values, one per iteration.
        output_dir: Directory (str ending in a separator, or PathLike)
            that receives ``para.png``.
    """
    _plot_curve(npass, data, output_dir, "para.png",
                "Parameter difference", "Para diff")
