import numpy as np
import sys
import os
import shutil

from scipy import interpolate
import scipy.ndimage.filters as sfilters

sys.path.append("utils/")
from exr_utils import *
from concentric import *

from opt_heightmap import opt_heightmap
from compute_ST import compute_ST

downsample_scale = 64


def resample_heightmap(img, reso, texel_size, height_offset=0.01):
    """Bilinearly resample a heightmap onto a ``reso`` x ``reso`` grid.

    The input is treated as texels of size ``texel_size`` centred on the
    origin, so its samples span ``[-xlim, xlim]`` x ``[-ylim, ylim]``. The
    output samples cover the same extent. The result is then shifted so its
    minimum height equals ``height_offset``.

    Args:
        img: 2-D array of heights indexed ``img[row, col]``, where rows run
            along y and columns along x.
        reso: number of output samples along each axis.
        texel_size: world-space size of one input texel.
        height_offset: minimum height of the returned map. Defaults to 0.01,
            the value previously hard-coded here.

    Returns:
        ``(img_new, xlim, ylim)``: the ``(reso, reso)`` resampled heightmap and
        the half-extents of the covered region along x and y.
    """
    ylim = img.shape[0] * texel_size * 0.5
    xlim = img.shape[1] * texel_size * 0.5
    y = np.linspace(-ylim, ylim, img.shape[0])
    x = np.linspace(-xlim, xlim, img.shape[1])

    # scipy.interpolate.interp2d was removed in SciPy 1.14. A degree-1,
    # zero-smoothing RectBivariateSpline is SciPy's recommended replacement
    # for interp2d(kind="linear") on a regular grid. Its first axis is the
    # array's first axis, so pass (y, x) to match img's (row, col) layout.
    spline = interpolate.RectBivariateSpline(y, x, img, kx=1, ky=1, s=0)
    y_new = np.linspace(-ylim, ylim, reso)
    x_new = np.linspace(-xlim, xlim, reso)
    img_new = spline(y_new, x_new)

    # Shift so that the lowest point sits at height_offset.
    img_new += height_offset - np.min(img_new)
    return img_new, xlim, ylim


def compute_slopes(out_dir):
    """Run Mitsuba's ``LEADR_plane`` utility and collect its slope outputs.

    Changes into ``out_dir`` and runs
    ``mtsutil LEADR_plane gen_slopes.xml 1024 4 <downsample_scale>``. It then
    moves ``uv.exr`` and every ``moments0_<ds>x.exr`` / ``moments1_<ds>x.exr``
    (ds = 1, 2, 4, ..., ``downsample_scale``) into ``out_dir/slopes/``,
    overwriting files left by a previous run. Output files that mtsutil did
    not produce are skipped. The caller's working directory is restored
    afterwards, even if an error is raised.

    Args:
        out_dir: directory containing ``gen_slopes.xml``. The mtsutil path is
            relative (two directories up, then ``mitsuba/dist``), so this is
            assumed to be two levels below the project root. TODO: confirm
            for other layouts.
    """
    prev_cwd = os.getcwd()
    os.chdir(out_dir)
    try:
        # usage: mtsutil LEADR_plane [scene_filename] [resolution] [sqrtSpp] [downsample_scale]
        # Raw string: in a normal literal the Windows backslashes are invalid
        # escape sequences (a SyntaxWarning on Python 3.12+).
        cmd = r'"..\..\mitsuba\dist\mtsutil"'
        cmd += " LEADR_plane"
        cmd += " gen_slopes.xml"
        cmd += " {:d} {:d} {:d}".format(1024, 4, downsample_scale)
        status = os.system(cmd)
        if status != 0:
            sys.stderr.write(
                "warning: mtsutil LEADR_plane exited with status {:d}\n".format(status))

        name = "slopes"
        os.makedirs(name, exist_ok=True)

        outputs = ["uv.exr"]
        ds = 1
        while ds <= downsample_scale:
            outputs.append("moments0_{:d}x.exr".format(ds))
            outputs.append("moments1_{:d}x.exr".format(ds))
            ds *= 2

        for fp in outputs:
            try:
                # os.replace overwrites an existing destination on every
                # platform. os.rename fails on Windows if the file already
                # exists, which left stale results from an earlier run behind.
                os.replace(fp, os.path.join(name, fp))
            except FileNotFoundError:
                # This output was not produced; collecting it is best-effort.
                pass
    finally:
        os.chdir(prev_cwd)


def compute_vmf_lobes(out_dir, xlim, ylim, num_lobes=6):
    """Run Mitsuba's ``normalMipmap`` utility and collect the vMF lobe maps.

    Changes into ``out_dir`` and runs
    ``mtsutil normalMipmap <num_lobes> 1024 8 <downsample_scale> gen_vmf_lobes.xml``.
    It then moves ``vmf_64x_lobe_<i>.exr`` for i in ``range(num_lobes)`` into
    ``out_dir/vmf_lobes/``, overwriting files left by a previous run. The
    caller's working directory is restored afterwards, even if an error is
    raised.

    Args:
        out_dir: directory containing ``gen_vmf_lobes.xml``. The mtsutil path
            is relative (two directories up, then ``mitsuba/dist``), so this
            is assumed to be two levels below the project root. TODO: confirm
            for other layouts.
        xlim, ylim: heightmap half-extents. The code does not read them; the
            scene file ``gen_vmf_lobes.xml`` is expected to use a scale of 40
            and x/y scales of ``40 * xlim`` and ``40 * ylim``. Keep those
            values in sync by hand.
        num_lobes: number of lobes requested from mtsutil and collected
            afterwards (default 6, the previously hard-coded value).

    Raises:
        FileNotFoundError: if an expected lobe file was not produced.
    """
    prev_cwd = os.getcwd()
    os.chdir(out_dir)
    try:
        # usage: mtsutil normalMipmap [num_lobes] [resolution] [sqrtSpp] [downsample_scale] [scene_filename]
        # Raw string: in a normal literal the Windows backslashes are invalid
        # escape sequences (a SyntaxWarning on Python 3.12+).
        cmd = r'"..\..\mitsuba\dist\mtsutil"'
        cmd += " normalMipmap"
        cmd += " {:d} {:d} {:d} {:d}".format(num_lobes, 1024, 8, downsample_scale)
        cmd += " gen_vmf_lobes.xml"
        status = os.system(cmd)
        if status != 0:
            sys.stderr.write(
                "warning: mtsutil normalMipmap exited with status {:d}\n".format(status))

        name = "vmf_lobes"
        os.makedirs(name, exist_ok=True)
        for i in range(num_lobes):
            # NOTE(review): "64x" presumably mirrors downsample_scale (= 64).
            # Confirm against mtsutil before changing downsample_scale.
            fp = "vmf_64x_lobe_{:d}.exr".format(i)
            # os.replace overwrites an existing destination; os.rename fails
            # on Windows when rerunning into a populated vmf_lobes/.
            os.replace(fp, os.path.join(name, fp))
    finally:
        os.chdir(prev_cwd)


if __name__ == "__main__":
    # Full prefiltering pipeline for the "twill" displacement map. All paths
    # are relative, so this assumes it is run from the project's python/
    # directory.

    # load image
    img = load_exr("../data/disp_maps/twill.exr")

    out_dir = "../data/prefilter_twill/"
    os.makedirs(out_dir, exist_ok=True)

    # resample image to (1k)^2 heightmap
    img, xlim, ylim = resample_heightmap(img, 1024, 0.001)
    save_exr(out_dir + "base_heightmap_1.exr", img)

    # compute slopes
    compute_slopes(out_dir)

    # optimize for lower-resolution heightmap
    opt_heightmap(out_dir, downsample_scale)

    # compute vmf lobes
    compute_vmf_lobes(out_dir, xlim, ylim)

    # compute S and T
    compute_ST(out_dir)

    # copy files into the render scene directory
    scene_dir = "../data/render_twill/"

    # shutil.copytree raises FileExistsError if the destination exists, which
    # made every rerun crash here. Copy the lobe files individually into a
    # directory that may already exist, overwriting older versions.
    lobe_src = os.path.join(out_dir, "vmf_lobes")
    lobe_dst = os.path.join(scene_dir, "vmf_lobes")
    os.makedirs(lobe_dst, exist_ok=True)
    for fname in os.listdir(lobe_src):
        src = os.path.join(lobe_src, fname)
        if os.path.isfile(src):
            shutil.copy2(src, os.path.join(lobe_dst, fname))

    # NOTE(review): "opt_heightmap_64" presumably mirrors downsample_scale
    # (= 64). Confirm against opt_heightmap's output naming before changing it.
    for fname in ("base_heightmap_1.exr", "opt_heightmap_64.exr", "S.exr", "T.exr"):
        shutil.copyfile(out_dir + fname, scene_dir + fname)
