import numpy as np
import sys
import os
import time
import psutil
import shutil

from scipy import interpolate
import scipy.ndimage.filters as sfilters

sys.path.append("utils/")
from exr_utils import *
from concentric import *

from twill_config import *

# When True, the high-resolution effBrdfEstimator runs in solve_for_T use the
# configured min/max path depth, and solve_for_S compares against the
# "_order_all" reference lobes; when False, depth is fixed to "1 1" and the
# "_order_1" reference lobes are used instead.
run_gi = True

# Prepend the Mitsuba distribution directory so that `mtsutil` resolves.
# Raw string: the Windows path contains backslashes that are not valid escape
# sequences (\L, \O, \s, ...), which trigger Deprecation/SyntaxWarnings in a
# normal literal. The resulting value is byte-identical.
# NOTE(review): machine-specific absolute path; consider moving it into config.
os.environ["PATH"] = r"D:\Lifan\OneDrive - UC San Diego\s2019_code_data\mitsuba\dist\;" + os.environ["PATH"]
mts = "mtsutil"


class Solve(object):
    """Estimate the S and T ratio maps by comparing high- and low-resolution renders.

    Every simulation is delegated to the Mitsuba utilities ``bsdfSimulator`` and
    ``effBrdfEstimator`` (invoked through the ``mts`` command). S and T are
    built from per-channel ratios of the high-resolution results over the
    low-resolution results.
    """

    def __init__(self, config):
        """Keep a reference to *config* (mutated in place) and derive block parameters."""
        self.config = config
        self.configure()

    # ------------------------------------------------------------------ setup

    def configure(self):
        """Load the heightmaps and derive the S/T block sampling parameters.

        Adds ``height_filename``, ``downsampled_height_filename``, ``x`` and
        ``y`` to ``self.config``. Fills ``self.s_params`` / ``self.t_params``
        with the uv block size, first-block origin and block step.
        """
        self.config["height_filename"] = self.config["data_dir"] + "base_heightmap_1.exr"
        # `downsample_scale` is a module-level name (presumably from twill_config).
        self.config["downsampled_height_filename"] = self.config["data_dir"] + \
            "opt_heightmap_%d.exr" % (downsample_scale)
        self.hmap = load_exr(self.config["height_filename"])
        self.hmap_low = load_exr(self.config["downsampled_height_filename"])
        self.config["x"] = self.hmap.shape[1]
        self.config["y"] = self.hmap.shape[0]
        print("x = {:d}, y = {:d}".format(self.config["x"], self.config["y"]))

        # uv extent of one heightmap texel once the map is tiled tile_x x tile_y times.
        u_texel_size = 1.0 / (self.config["x"] * self.config["tile_x"])
        v_texel_size = 1.0 / (self.config["y"] * self.config["tile_y"])

        self.s_params = self._block_params(self.config["s_block_size"], self.config["s_xy_reso"],
                                           u_texel_size, v_texel_size)
        self.t_params = self._block_params(self.config["t_block_size"], self.config["t_xy_reso"],
                                           u_texel_size, v_texel_size)

    def _block_params(self, block_size, xy_reso, u_texel_size, v_texel_size):
        """Return the uv block size, origin and step for a grid of ``xy_reso`` blocks per axis.

        Each block is ``block_size`` texels wide. Consecutive blocks are
        ``self.config["x"] // xy_reso`` texels apart.
        """
        num_texels_per_block = float(block_size)
        # NOTE(review): the step is derived from the x resolution only and reused
        # for v as well, which assumes a square heightmap — confirm.
        attenuation_map_scale = self.config["x"] // xy_reso
        return {
            "u_block_size": u_texel_size * num_texels_per_block,
            "v_block_size": v_texel_size * num_texels_per_block,
            "u_min": 0.5 - u_texel_size * (num_texels_per_block - attenuation_map_scale) * 0.5,
            "v_min": 0.5 - v_texel_size * (num_texels_per_block - attenuation_map_scale) * 0.5,
            "u_step": u_texel_size * attenuation_map_scale,
            "v_step": v_texel_size * attenuation_map_scale,
        }

    @staticmethod
    def _block_origin(block_params, x, y):
        """Return the (u, v) start of block (x, y) for the given block parameters."""
        u_st = block_params["u_min"] + block_params["u_step"] * x
        v_st = block_params["v_min"] + block_params["v_step"] * y
        return u_st, v_st

    def _block_extent(self, u_st, v_st, block_params):
        """Map a block starting at (u_st, v_st) to scene coordinates.

        u in [0, 1] maps linearly to [-x_scale, x_scale], and v likewise to
        [-y_scale, y_scale]. Returns ``(x_st, x_ed, y_st, y_ed)``.
        """
        x_scale = self.config["x_scale"]
        y_scale = self.config["y_scale"]
        x_st = -x_scale + u_st * 2.0 * x_scale
        x_ed = -x_scale + (u_st + block_params["u_block_size"]) * 2.0 * x_scale
        y_st = -y_scale + v_st * 2.0 * y_scale
        y_ed = -y_scale + (v_st + block_params["v_block_size"]) * 2.0 * y_scale
        return x_st, x_ed, y_st, y_ed

    @staticmethod
    def _ensure_dir(path):
        """Create *path* (and parents) if it is missing, then return it."""
        os.makedirs(path, exist_ok=True)
        return path

    def _wi_samples(self):
        """Return the 1-D grid of ``wi_reso`` sample coordinates used for incident directions."""
        smallest = 0.02
        # Asymmetric end points to avoid numerical issues at the domain boundary.
        return np.linspace(smallest, 1.0 - smallest * 0.9, self.config["wi_reso"])

    def _s_lobe_name(self, r, c):
        """Return the file-name stem of the lobe for incident-direction grid cell (r, c)."""
        return self.config["s_res_filename_prefix"] + "wi_" + str(r) + "_" + str(c)

    # --------------------------------------------------------------- sampling

    def sample_directions(self, N):
        """Draw N (wi, wo) direction pairs via ``squareToCosineHemisphere``.

        Both square coordinates are drawn from ``np.random.uniform(0.1, 0.9)``.
        Returns two ``(N, 3)`` float arrays ``(wis, wos)``.
        """
        wis = np.zeros((N, 3), dtype=float)
        wos = np.zeros((N, 3), dtype=float)
        for i in range(N):
            # wi is drawn before wo so the seeded random stream is consumed in a fixed order.
            wis[i][0], wis[i][1], wis[i][2] = squareToCosineHemisphere(np.random.uniform(0.1, 0.9), np.random.uniform(0.1, 0.9))
            wos[i][0], wos[i][1], wos[i][2] = squareToCosineHemisphere(np.random.uniform(0.1, 0.9), np.random.uniform(0.1, 0.9))
        return wis, wos

    def prepare_params_input(self, u_st, v_st, wis, wos,
                             res_dir, res_filename):
        """Write the effBrdfEstimator input table and return its path.

        Each of the N rows is ``x_st x_ed y_st y_ed wi(3) wo(3)``. The block
        extent comes from the T block parameters.
        """
        N = wis.shape[0]
        params = np.zeros((N, 10), dtype=float)
        params[:, 0:4] = self._block_extent(u_st, v_st, self.t_params)
        params[:, 4:7] = wis
        params[:, 7:10] = wos
        out_filename = res_dir + res_filename
        np.savetxt(out_filename, params, fmt="%.8f")
        return out_filename

    # ------------------------------------------------------ external commands

    def _mts_command(self, tool, to_stdout):
        """Return the ``mts -p <workers> [-q] <tool>`` prefix.

        Uses two fewer workers than CPUs, but always at least one.
        """
        # cpu_count() may return None; clamp so small machines never get "-p 0" or negative.
        n_workers = max(1, (psutil.cpu_count() or 1) - 2)
        quiet = "" if to_stdout else " -q"
        return mts + " -p {:d}{} {}".format(n_workers, quiet, tool)

    def _depth_args(self, use_gi):
        """Return the path-depth tokens: configured min/max depth with GI, else "1 1"."""
        if use_gi:
            return [str(self.config["min_depth"]), str(self.config["max_depth"])]
        return ["1", "1"]

    def _st_args(self, s_filename, t_filename):
        """Return the optional quoted S and T file tokens (empty when no S file is given)."""
        if s_filename == "":
            return []
        if t_filename == "":
            print("Error: should give both S and T...")
        return [self.path_wrapper(s_filename), self.path_wrapper(t_filename)]

    @staticmethod
    def _run_command(cmd, to_stdout):
        """Run *cmd* through the shell, echoing it if requested and warning on failure."""
        if to_stdout:
            print(cmd)
        status = os.system(cmd)
        if status != 0:
            print("Warning: command exited with status {}: {}".format(status, cmd))
        return status

    def run_bsdf_simulator(self, scene_filename, wi, sqrt_spp,
                           u_st, v_st, use_gi, res_filename,
                           s_filename, t_filename, to_stdout=False):
        """Run ``bsdfSimulator`` for one incident direction over one S block.

        Output files are written by the tool under the ``res_filename`` prefix.
        """
        x_st, x_ed, y_st, y_ed = self._block_extent(u_st, v_st, self.s_params)
        use_full_sphere = 0
        dist_gi_texel_scale = self.config["x"] / self.config["s_block_size"] * 1.0

        tokens = [self.path_wrapper(scene_filename),
                  str(wi[0]), str(wi[1]), str(wi[2]),
                  str(sqrt_spp),
                  str(self.config["wo_reso"]),
                  str(x_st), str(x_ed),
                  str(y_st), str(y_ed)]
        tokens += self._depth_args(use_gi)
        tokens.append(str(self.config["shadow_option"]))
        tokens.append(self.path_wrapper(res_filename))
        tokens += [str(use_full_sphere), str(dist_gi_texel_scale)]
        tokens += self._st_args(s_filename, t_filename)

        cmd = self._mts_command("bsdfSimulator", to_stdout) + " " + " ".join(tokens)
        self._run_command(cmd, to_stdout)

    def run_eff_brdf_estimator(self, scene_filename, params_filename, sqrt_spp, use_gi,
                               s_filename, t_filename, res_filename, to_stdout=False):
        """Run ``effBrdfEstimator`` on a params table and return ``<res_filename>_values.txt``.

        The result is loaded with ``np.loadtxt``.
        """
        dist_gi_texel_scale = self.config["x"] / self.config["t_block_size"] * 0.5

        tokens = [self.path_wrapper(scene_filename),
                  self.path_wrapper(params_filename),
                  str(sqrt_spp)]
        tokens += self._depth_args(use_gi)
        # NOTE(review): shadow_option is passed twice; presumably the tool takes two
        # shadow settings — confirm against the effBrdfEstimator argument list.
        tokens += [str(self.config["shadow_option"]), str(self.config["shadow_option"])]
        tokens.append(str(dist_gi_texel_scale))
        tokens.append(self.path_wrapper(res_filename))
        tokens += self._st_args(s_filename, t_filename)

        cmd = self._mts_command("effBrdfEstimator", to_stdout) + " " + " ".join(tokens)
        self._run_command(cmd, to_stdout)

        values = np.loadtxt(res_filename + "_values.txt")
        return values

    def _simulate_lobes(self, scene_filename, res_dir, u_st, v_st, sqrt_spp, use_gi,
                        s_filename, t_filename):
        """Run bsdfSimulator for every cell of the wi_reso x wi_reso incident-direction grid."""
        wi_reso = self.config["wi_reso"]
        t_wi = self._wi_samples()
        wi = np.zeros((3), dtype=float)
        for r in range(wi_reso):
            for c in range(wi_reso):
                wi[0], wi[1], wi[2] = squareToHemisphere(t_wi[c], t_wi[r])
                print(r, c, wi)
                self.run_bsdf_simulator(scene_filename, wi, sqrt_spp,
                                        u_st, v_st, use_gi,
                                        res_dir + self._s_lobe_name(r, c),
                                        s_filename, t_filename, False)

    # ----------------------------------------------------------------- solves

    def gen_ref_lobes(self):
        """Simulate the reference lobes used by ``solve_for_S``.

        Runs the high-resolution scene with GI and sqrt_spp = 1000 for every S
        block. Output goes to ``data_dir/s_high_res_dir/p_<x>_<y>/``.
        """
        res_base_dir = self._ensure_dir(self.config["data_dir"] + self.config["s_high_res_dir"] + "/")
        scene_filename = self.config["data_dir"] + self.config["scene_filename"]
        xy_size = self.config["s_xy_reso"]

        for y in range(xy_size):
            for x in range(xy_size):
                u_st, v_st = self._block_origin(self.s_params, x, y)
                print("====== (%.6f, %.6f) ======" % (u_st, v_st))
                res_dir = self._ensure_dir(res_base_dir + ("p_%d_%d" % (x, y)) + "/")
                self._simulate_lobes(scene_filename, res_dir, u_st, v_st, 1000, True, "", "")

    @staticmethod
    def _estimate_T_texel(values_high, values_low, wis, wos, epsilon=1e-6, max_ratio=1e3):
        """Monte-Carlo average of clipped high/low ratios, divided by the sample pdf.

        The pdf is ``cos(wi)/pi * cos(wo)/pi``, taken from the z components.
        Values are reshaped to ``(N, -1)``, so N == 1 (where ``np.loadtxt``
        returns a 1-D row) is handled correctly. Returns a length-3 array.
        """
        N = wis.shape[0]
        high = np.asarray(values_high, dtype=float).reshape(N, -1)
        low = np.asarray(values_low, dtype=float).reshape(N, -1)
        ratio = np.clip((high + epsilon) / (low + epsilon), 0, max_ratio)
        pdf = wis[:, 2] / np.pi * wos[:, 2] / np.pi
        int_value = np.zeros((3), dtype=float)
        # A single-column result, shape (1,), broadcasts to all three channels.
        int_value += (ratio / pdf[:, np.newaxis]).sum(axis=0)
        int_value /= N
        return int_value

    def solve_for_T(self, reuse_cached=False):
        """Estimate T per T block and save it as ``tmp_dir/t_low_res_dir/<t_scale_filename>.exr``.

        With ``reuse_cached=True``, the ``eff_brdf_{high,low}_values.txt`` files
        from a previous run are loaded instead of re-running effBrdfEstimator.
        Directions are still re-sampled, so the random seed must match that run.
        Returns the output path.
        """
        res_base_dir = self._ensure_dir(self.config["tmp_dir"] + self.config["t_low_res_dir"] + "/")
        xy_size = self.config["t_xy_reso"]

        epsilon = 1e-6
        T = np.zeros((xy_size, xy_size, 3), dtype=float)

        S_input = self.config["data_dir"] + "S_init.exr"
        T_input = self.config["data_dir"] + "T_init.exr"
        scene_high = self.config["data_dir"] + self.config["scene_filename"]
        scene_low = self.config["data_dir"] + self.config["scene_filename_low"]
        sqrt_spp = self.config["t_sqrt_samples_per_block"]
        N = self.config["t_4D_samples"]

        for y in range(xy_size):
            for x in range(xy_size):
                print("running (%d, %d)" % (x, y))

                u_st, v_st = self._block_origin(self.t_params, x, y)
                res_dir = self._ensure_dir(res_base_dir + ("p_%d_%d" % (x, y)) + "/")

                wis, wos = self.sample_directions(N)
                params_filename = self.prepare_params_input(u_st, v_st, wis, wos,
                                                            res_dir, "input_params.txt")

                if reuse_cached:
                    values_high = np.loadtxt(res_dir + "eff_brdf_high_values.txt")
                    values_low = np.loadtxt(res_dir + "eff_brdf_low_values.txt")
                else:
                    values_high = self.run_eff_brdf_estimator(
                        scene_high, params_filename, sqrt_spp, run_gi,
                        S_input, T_input, res_dir + "eff_brdf_high")
                    values_low = self.run_eff_brdf_estimator(
                        scene_low, params_filename, sqrt_spp, False,
                        S_input, T_input, res_dir + "eff_brdf_low")

                int_value = self._estimate_T_texel(values_high, values_low, wis, wos, epsilon)
                T[y, x, :] = int_value
                print(int_value)

        T_output = res_base_dir + self.config["t_scale_filename"] + ".exr"
        save_exr(T_output, T)
        return T_output

    def process_T(self, T_input, smooth):
        """Optionally blur T, normalise each channel to unit mean, and save ``*_blur.exr``.

        The blur is a Gaussian with sigma 0.5 and wrap-around boundaries,
        applied per channel. Returns the output path.
        """
        # Non-deprecated location of gaussian_filter (scipy.ndimage.filters is deprecated).
        from scipy import ndimage

        T = load_exr(T_input)
        T_blur = np.copy(T)

        if smooth:
            sigma = 0.5
            for k in range(3):
                T_blur[:, :, k] = ndimage.gaussian_filter(T[:, :, k], sigma, mode='wrap')

        # NOTE(review): a channel whose mean is 0 would produce inf/nan here.
        T_avg = np.mean(T_blur, axis=(0, 1))
        for k in range(3):
            T_blur[:, :, k] /= T_avg[k]

        T_output = self.config["tmp_dir"] + self.config["t_low_res_dir"] + "/" + \
            self.config["t_scale_filename"] + "_blur.exr"
        save_exr(T_output, T_blur)
        return T_output

    @staticmethod
    def _lobe_ratio(lobe_high, lobe_low, wo_reso, epsilon=1e-6, max_ratio=1e3):
        """Return min((high + eps) / (low + eps), max_ratio) as a (wo_reso**2, 3) table.

        Rows are ordered ``r1 * wo_reso + c1``. Only the first
        ``wo_reso x wo_reso`` texels and the first three channels are used.
        """
        high = np.asarray(lobe_high, dtype=float)[:wo_reso, :wo_reso, :3]
        low = np.asarray(lobe_low, dtype=float)[:wo_reso, :wo_reso, :3]
        ratio = np.minimum((high + epsilon) / (low + epsilon), max_ratio)
        return ratio.reshape(wo_reso * wo_reso, 3)

    def solve_for_S(self, T_input, reuse_cached=False):
        """Estimate S from reference/low-resolution lobe ratios and save it.

        The output is ``tmp_dir/s_low_res_dir/<s_scale_filename>.exr``. Low
        lobes use the low-resolution scene with depth "1 1", ``S_init.exr`` and
        *T_input*. With ``reuse_cached=True``, previously simulated low lobes
        are read back without re-running bsdfSimulator. Returns the output path.
        """
        res_base_dir = self._ensure_dir(self.config["tmp_dir"] + self.config["s_low_res_dir"] + "/")
        xy_size = self.config["s_xy_reso"]
        wi_reso = self.config["wi_reso"]
        wo_reso = self.config["wo_reso"]

        epsilon = 1e-6
        S = np.zeros((wi_reso * wi_reso, wo_reso * wo_reso, 3), dtype=float)

        S_input = self.config["data_dir"] + "S_init.exr"
        S_output = res_base_dir + self.config["s_scale_filename"] + ".exr"
        scene_low = self.config["data_dir"] + self.config["scene_filename_low"]
        high_suffix = "_order_all.exr" if run_gi else "_order_1.exr"

        for y in range(xy_size):
            for x in range(xy_size):
                u_st, v_st = self._block_origin(self.s_params, x, y)
                block_name = "p_%d_%d" % (x, y)
                res_dir = self._ensure_dir(res_base_dir + block_name + "/")

                if not reuse_cached:
                    self._simulate_lobes(scene_low, res_dir, u_st, v_st,
                                         self.config["s_sqrt_samples_per_block"], False,
                                         S_input, T_input)

                high_dir = self.config["data_dir"] + self.config["s_high_res_dir"] + "/" + block_name + "/"
                for r in range(wi_reso):
                    for c in range(wi_reso):
                        lobe_name = self._s_lobe_name(r, c)
                        lobe_low = load_exr(res_dir + lobe_name + "_order_1.exr")
                        lobe_high = load_exr(high_dir + lobe_name + high_suffix)
                        # NOTE(review): S is not indexed by (x, y), so with s_xy_reso > 1
                        # only the last block's ratios survive — confirm this is intended.
                        S[r * wi_reso + c] = self._lobe_ratio(lobe_high, lobe_low, wo_reso, epsilon)

        # Saved once, after all blocks; the file contents match the old per-block rewrite.
        save_exr(S_output, S)
        return S_output

    def path_wrapper(self, name):
        """Wrap *name* in double quotes for the shell command line (no escaping)."""
        return "\"" + name + "\""


def compute_ST(out_dir):
    """Run the full S/T pipeline with *out_dir* as the working directory.

    Steps: solve T, normalise it (no smoothing), generate reference lobes,
    solve S, then copy the results to ``S.exr`` and ``T.exr`` in *out_dir*.
    The random seed is fixed to 23333 so direction sampling is reproducible.
    The caller's working directory is restored afterwards, even if a step raises.
    """
    # Restore the caller's cwd rather than assuming it was out_dir/../../python/.
    prev_dir = os.getcwd()
    os.chdir(out_dir)
    try:
        np.random.seed(23333)
        sol = Solve(twill_config)

        T_filename = sol.solve_for_T()
        T_filename = sol.process_T(T_filename, False)

        # ref lobes for S
        sol.gen_ref_lobes()
        S_filename = sol.solve_for_S(T_filename)

        shutil.copyfile(S_filename, "S.exr")
        shutil.copyfile(T_filename, "T.exr")
    finally:
        os.chdir(prev_dir)
