# Copyright @yucwang 2022

import igl
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

def compute_mesh_distance_from_files(source_file, target_file):
    """Load two OBJ meshes from disk and return compute_mesh_distance on them.

    The source file is read first, then the target file. From each
    ``igl.read_obj`` result only the vertex array (1st entry) and the face
    array (4th entry) are kept; the other entries are discarded.
    """
    def _vertices_and_faces(path):
        # read_obj yields a 6-tuple; keep only positions and face indices.
        verts, _, _, faces, _, _ = igl.read_obj(path)
        return verts, faces

    src_verts, src_faces = _vertices_and_faces(source_file)
    tgt_verts, tgt_faces = _vertices_and_faces(target_file)
    return compute_mesh_distance(src_verts, src_faces, tgt_verts, tgt_faces)


def compute_mesh_distance(Vi, Fi, Vt, Ft):
    """Return a two-sided, scale-normalized mean distance between two meshes.

    Both meshes are divided by the largest side of the target mesh's
    axis-aligned bounding box, so the result is independent of the target's
    absolute scale. The returned value is the mean distance from every vertex
    of the input mesh to the target surface plus the mean distance from every
    vertex of the target mesh to the input surface (a Chamfer-style metric
    sampled at mesh vertices only).

    Args:
        Vi: vertex positions of the input mesh, one row per vertex.
        Fi: face indices of the input mesh.
        Vt: vertex positions of the target mesh, one row per vertex.
        Ft: face indices of the target mesh.

    Returns:
        Sum of the two directional mean (unsquared) distances.

    Raises:
        ValueError: if the target mesh has zero (or non-finite) extent, which
            would otherwise divide by zero and produce a non-finite distance.
    """
    max_len = (Vt.max(axis=0) - Vt.min(axis=0)).max()
    # `not > 0` also rejects NaN, not just an exact zero extent.
    if not max_len > 0:
        raise ValueError(
            "target mesh has degenerate bounding box (max extent {})".format(max_len)
        )
    _Vi = Vi / max_len
    _Vt = Vt / max_len

    # point_mesh_squared_distance returns squared distances; sqrt before averaging.
    dist1, _, _ = igl.point_mesh_squared_distance(_Vi, _Vt, Ft)
    dist2, _, _ = igl.point_mesh_squared_distance(_Vt, _Vi, Fi)

    return np.sqrt(dist1).mean() + np.sqrt(dist2).mean()

class TensorboardDashboard:
    """Thin wrapper around a TensorBoard ``SummaryWriter`` with a tag registry.

    Values are only written for names that were registered (through the
    constructor or ``add_function``). Every write is stamped with an internal
    iteration counter that ``step`` advances and ``reset`` sets back to zero.
    """

    def __init__(self, log_dir, func_names=None):
        """Open a ``SummaryWriter`` on *log_dir* and register *func_names*.

        Args:
            log_dir: directory passed to ``SummaryWriter(log_dir=...)``.
            func_names: optional iterable of names to register up front.
        """
        self.log_dir = log_dir
        # Copy so later add_function calls never mutate the caller's object;
        # `is not None` avoids elementwise `!=` on array-like inputs.
        self.func_names = [] if func_names is None else list(func_names)

        self.cur_iteration = 0
        self.summary_writer = SummaryWriter(log_dir=self.log_dir)

    def reset(self):
        """Set the iteration counter back to zero."""
        self.cur_iteration = 0

    def add_function(self, func_name):
        """Register *func_name* so write_to_summary will accept it.

        Registering a name that is already present prints a warning and
        leaves the registry unchanged.
        """
        if func_name in self.func_names:
            print("Warning: function {} has already existed.".format(func_name))
            return
        self.func_names.append(func_name)

    def write_to_summary(self, func_name, func_type, value):
        """Write *value* under tag *func_name* at the current iteration.

        Args:
            func_name: a registered name; unregistered names print an error
                and nothing is written.
            func_type: ``'scalar'`` dispatches to ``add_scalar``; ``'image'``
                dispatches to ``add_image`` with ``dataformats='HWC'``. Any
                other value prints an error and nothing is written.
            value: the scalar or image to log; for ``'image'`` it is
                interpreted as height x width x channels.
        """
        if func_name not in self.func_names:
            print("Error: function {} not registered.".format(func_name))
            return
        if func_type == 'scalar':
            self.summary_writer.add_scalar(func_name, value, self.cur_iteration)
        elif func_type == 'image':
            self.summary_writer.add_image(func_name, value, self.cur_iteration, dataformats='HWC')
        else:
            print("Error: unknown function type {}.".format(func_type))

    def step(self):
        """Advance the iteration counter used to stamp subsequent writes."""
        self.cur_iteration = self.cur_iteration + 1