import numpy as np
import sys
import os

from scipy.sparse import csr_matrix, csc_matrix
from scipy.sparse import tril
from scipy.sparse.linalg import inv, spsolve

sys.path.append("utils/")
from exr_utils import *


def _build_system(slopes):
    """Assemble the sparse least-squares system ``A @ h ~= y`` on a periodic grid.

    The N x N height grid is treated as periodic (indices wrap modulo N).
    For every vertex ``v = r * N + c``, A has three rows:

    * ``3*v``     x slope: 0.5 * (h[r, c+1] + h[r+1, c+1] - h[r, c] - h[r+1, c])
    * ``3*v + 1`` y slope: 0.5 * (h[r+1, c] + h[r+1, c+1] - h[r, c] - h[r, c+1])
    * ``3*v + 2`` 5-point Laplacian weighted by sqrt(0.01), target value 0

    Args:
        slopes: float array of shape (N, N, >=2); channel 0 is the x-slope
            target and channel 1 the y-slope target of each vertex.

    Returns:
        ``(A, y)`` where A is a (3*N*N, N*N) CSR matrix and y is a float64
        vector of length 3*N*N.
    """
    n = slopes.shape[0]
    rr, cc = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    rr = rr.ravel()
    cc = cc.ravel()

    def vertex(dr, dc):
        # Flat index of the neighbour (r + dr, c + dc), wrapping periodically.
        return ((rr + dr) % n) * n + ((cc + dc) % n)

    v = vertex(0, 0)
    w_reg = np.sqrt(0.01)

    # (row offset inside the vertex's 3-row group, neighbour offset, coefficient)
    stencil = [
        # x slope
        (0, (0, 0), -0.5), (0, (1, 0), -0.5),
        (0, (0, 1), 0.5), (0, (1, 1), 0.5),
        # y slope
        (1, (0, 0), -0.5), (1, (0, 1), -0.5),
        (1, (1, 0), 0.5), (1, (1, 1), 0.5),
        # Laplacian regularisation
        (2, (-1, 0), w_reg), (2, (1, 0), w_reg),
        (2, (0, -1), w_reg), (2, (0, 1), w_reg),
        (2, (0, 0), -4.0 * w_reg),
    ]
    rows = np.concatenate([3 * v + k for k, _, _ in stencil])
    cols = np.concatenate([vertex(dr, dc) for _, (dr, dc), _ in stencil])
    data = np.concatenate([np.full(n * n, coef) for _, _, coef in stencil])
    # Duplicate (row, col) pairs (possible for tiny N due to wrap-around) are
    # summed by the sparse constructor.
    A = csr_matrix((data, (rows, cols)), shape=(3 * n * n, n * n))

    y = np.zeros(3 * n * n)
    y[0::3] = slopes[:, :, 0].ravel()
    y[1::3] = slopes[:, :, 1].ravel()
    # y[2::3] stays 0: the Laplacian rows pull the surface towards smoothness.
    return A, y


def _reconstruct_heightmap(slopes):
    """Return the (N, N) least-squares heightmap for a square slope image.

    The slopes are divided by N/2 before solving. The result is shifted so
    that its minimum is 1e-2.

    Raises:
        ValueError: if ``slopes`` is not a non-empty (N, N, >=2) array.
    """
    slopes = np.asarray(slopes, dtype=np.float64)
    if (slopes.ndim != 3 or slopes.shape[0] == 0
            or slopes.shape[0] != slopes.shape[1] or slopes.shape[2] < 2):
        raise ValueError(
            "expected slopes of shape (N, N, >=2), got {}".format(slopes.shape))
    n = slopes.shape[0]
    # NOTE(review): scale factor carried over unchanged; presumably converts
    # slopes into per-texel units for an N x N grid -- confirm.
    slopes = slopes / (n * 0.5)

    A, y = _build_system(slopes)
    At = A.T
    C = (At @ A).tocsc()
    b = At @ y

    # Every row of A sums to zero (periodic wrap), so constant vectors lie in
    # A's null space and C = A^T A is singular. Pin vertex 0 to fix the free
    # offset. The shift below removes any constant anyway, so the output is
    # unchanged while spsolve gets a non-singular system.
    C = C + csc_matrix(([1.0], ([0], [0])), shape=C.shape)

    x = spsolve(C, b)
    h = x.reshape(n, n)
    return h - (h.min() - 1e-2)


def opt_heightmap(out_dir, downsample_scale):
    """Reconstruct a heightmap from a slope EXR and write it back as EXR.

    Inside ``out_dir``, this reads ``slopes/moments0_{downsample_scale}x.exr``
    and writes ``opt_heightmap_{downsample_scale}.exr``. The heightmap is the
    least-squares fit of the x/y slopes plus a weak Laplacian smoothness
    term on a periodic grid. It is shifted so that its minimum is 1e-2.

    The working directory is switched to ``out_dir`` for the I/O. It is
    always restored to the caller's original directory, even on error.
    (It was previously set to a hard-coded ``../../python/`` instead.)

    Args:
        out_dir: directory holding the ``slopes/`` input and receiving the output.
        downsample_scale: integer used in the input and output file names.

    Raises:
        ValueError: if the slope image is not a non-empty (N, N, >=2) array.
    """
    prev_dir = os.getcwd()
    os.chdir(out_dir)
    try:
        slopes = load_exr("slopes/moments0_{:d}x.exr".format(downsample_scale))
        h = _reconstruct_heightmap(slopes)
        save_exr("opt_heightmap_{:d}.exr".format(downsample_scale), h)
    finally:
        os.chdir(prev_dir)

