import argparse
from email.policy import default
import time
from typing import List
from cv2 import transform
import gin
import psdr_cpu
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.utils.exr2png import convertEXR2ColorMap
import dataclasses
import os
import numpy as np
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Fixed seed for torch's global RNG.
torch.random.manual_seed(10)
# Optional profiler (see the commented-out yep.start call in __main__):
# import yep

# Process-wide switches read from the environment: allow a duplicate OpenMP
# runtime to load (KMP_DUPLICATE_LIB_OK) and enable OpenCV's OpenEXR I/O
# (OPENCV_IO_ENABLE_OPENEXR).
os.environ.update(
    {
        'KMP_DUPLICATE_LIB_OK': 'True',
        "OPENCV_IO_ENABLE_OPENEXR": "1",
    }
)

# Symmetric bound passed as (-scale, scale) to convertEXR2ColorMap when the
# derivative images are colour-mapped. Per-scene values used so far:
# lucy mirror 50, nefertiti 5, lamp bulb 15, two triangles 1.2,
# bunny in glass 1.0, napoleon 5.
scale = 1.2

@gin.configurable
# dataclasses generates __init__; gin injects its arguments from the parsed config file.
@dataclasses.dataclass
class TestRunner:
    """Validation driver for psdr_cpu renders, configured through gin.

    ``run()`` dispatches on ``render_type``:

    * ``"forward"``  -> :meth:`renderC` (primal image),
    * ``"backward"`` -> :meth:`renderD` (derivative image w.r.t. the gin-bound
      ``Transform``),
    * ``"fd"``       -> :meth:`render_fd` (finite-difference reference),
    * ``"converge"`` -> :meth:`render_converge` (per-seed boundary derivative
      images plus their average).

    Outputs are written under ``out_dir``; file names start with ``prefix``
    (``<IntegratorClass>_<spp>_[<suffix>_]``).
    """

    scene: psdr_cpu.Scene
    options: psdr_cpu.RenderOptions
    mala_options: psdr_cpu.MALAOptions
    integrator: psdr_cpu.Integrator
    # Extra integrators whose derivative images are added on top in renderD().
    aug_integrators: List = dataclasses.field(default_factory=list)
    # Built per instance when TestRunner is constructed (i.e. after gin has
    # parsed the config) instead of once when the class body is evaluated, so
    # instances do not share a single GuidingOptions object.
    guiding_options: GuidingOptions = dataclasses.field(default_factory=GuidingOptions)
    out_dir: str = "./"
    # One of "forward", "backward", "fd", "converge" (see run()).
    render_type: str = "forward"
    # Optional tag appended to the output file prefix.
    suffix: str = ""
    # Step size used by render_fd().
    delta: float = 0.01
    # Assigned to every entry of aug_integrators in renderD(). It used to be
    # read there without being declared, so a non-empty aug_integrators
    # raised AttributeError.
    enable_antithetic: bool = False

    def __post_init__(self):
        # out_dir plus the "val" subdirectory that render_converge() writes into.
        os.makedirs(self.out_dir, exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, "val"), exist_ok=True)
        self.width = self.scene.camera.width
        self.height = self.scene.camera.height
        self.prefix = type(self.integrator).__name__ + \
            "_" + str(self.options.spp) + "_"
        if self.suffix != "":
            self.prefix += self.suffix + "_"
        psdr_cpu.set_verbose(True)

    def renderC(self):
        """Render the primal image and save it as ``<prefix>forward.exr`` / ``.png``."""
        image = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)
        imwrite(image, os.path.join(self.out_dir, self.prefix + "forward.exr"))
        imwrite(image, os.path.join(self.out_dir, self.prefix + "forward.png"))

    # ------------------------------------------------------------------
    # Private helpers shared by renderD / render_converge / render_fd
    # ------------------------------------------------------------------

    @staticmethod
    def _check_shape_id(xform):
        # Reject negative ids before indexing: a negative index would silently
        # select a shape counted from the end of scene.shapes.
        if xform.shape_id < 0:
            raise ValueError(
                f"Transform.shape_id must be >= 0, got {xform.shape_id}")

    def _apply_sort_config(self, shape):
        # Give `shape` the guiding sort config, enable drawing and reconfigure it.
        shape.sort_config = self.guiding_options.sort_config
        shape.enable_draw = True
        shape.configure()

    def _prepare_diff_scene(self, xform):
        """Set up ``self.scene`` for differentiating w.r.t. ``xform``.

        Only the target shape keeps ``enable_edge``. It receives the guiding
        sort config, is marked ``requires_grad`` and is given the translation
        or rotation held by ``xform.transformation``.

        Returns:
            ``(sceneAD, d_image)``: the ``psdr_cpu.SceneAD`` wrapping
            ``self.scene`` and the adjoint image of shape
            ``(height, width, 3)``, with 1 in channel 0 and 0 in channels 1-2.

        Raises:
            ValueError: if ``xform.shape_id`` is negative.
            TypeError: if ``xform.transformation`` is neither a ``Translation``
                nor a ``Rotation``.
        """
        self._check_shape_id(xform)
        if self.guiding_options.guide_type != "":
            self._apply_sort_config(self.scene.shapes[xform.shape_id])

        # Restrict edge sampling to the shape being differentiated.
        for shape in self.scene.shapes:
            shape.enable_edge = False
        self.scene.shapes[xform.shape_id].enable_edge = True
        self.scene.configure()

        # NOTE(review): this repeats the guided setup unconditionally (even when
        # guide_type == "") after scene.configure(). It is kept to preserve the
        # original behavior; confirm whether both passes are needed.
        self._apply_sort_config(self.scene.shapes[xform.shape_id])

        d_image = np.ones((self.height, self.width, 3))
        d_image[:, :, 1:3] = 0

        shape = self.scene.shapes[xform.shape_id]
        shape.requires_grad = True
        if xform.vertex_id >= 0:
            shape.vertex_idx = xform.vertex_id
        transformation = xform.transformation
        if type(transformation) is Translation:
            shape.setTranslation(transformation.translation)
        elif type(transformation) is Rotation:
            shape.setRotation(transformation.rotation)
        else:
            raise TypeError(
                f"unsupported transformation type {type(transformation).__name__}")
        return psdr_cpu.SceneAD(self.scene), d_image

    def _preprocess_guiding(self, sceneAD, boundary_integrator, ind_integrator,
                            dir_integrator=None):
        """Run the guiding preprocess passes selected by ``self.guiding_options``.

        Does nothing when ``guide_type`` is ``""``. ``guide_option`` picks the
        direct and/or indirect passes, and ``guide_type == "grid"`` selects
        grid preprocessing (anything else uses the aq variant). When
        ``dir_integrator`` is None, only ``boundary_integrator`` is
        preprocessed for direct edges.
        """
        opts = self.guiding_options
        if opts.guide_type == "":
            return
        scene = sceneAD.val
        max_bounces = self.options.max_bounces
        use_grid = opts.guide_type == "grid"

        if opts.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                if dir_integrator is not None:
                    dir_integrator.preprocess_grid(
                        scene, opts.grid_config_direct, max_bounces)
                boundary_integrator.preprocess_grid_direct(
                    scene, opts.grid_config_direct, max_bounces)
            else:
                if dir_integrator is not None:
                    dir_integrator.preprocess_aq(
                        scene, opts.aq_config_direct, max_bounces)
                boundary_integrator.preprocess_aq_direct(
                    scene, opts.aq_config_direct, max_bounces)

        if opts.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                ind_integrator.preprocess_grid(
                    scene, opts.grid_config_indirect, max_bounces)
                boundary_integrator.preprocess_grid_indirect(
                    scene, opts.grid_config_indirect, max_bounces)
            else:
                ind_integrator.preprocess_aq(
                    scene, opts.aq_config_indirect, max_bounces)
                boundary_integrator.preprocess_aq_indirect(
                    scene, opts.aq_config_indirect, max_bounces)

    def _write_derivative_image(self, img, sceneAD, xform, result_path):
        """Sanitize a flat derivative image, print summaries, save EXR + colour-mapped PNG.

        NaN pixels are set to 0 on a copy (``torch.tensor`` copies ``img``), so
        the caller's array is not modified.
        """
        img = torch.tensor(img)
        img[img.isnan()] = 0
        img = img.reshape((self.height, self.width, 3))
        print("img grad:", img.clone().detach().sum())
        print("grad: ", torch.tensor(
            np.array(sceneAD.der.shapes[xform.shape_id].vertices))[0][2])
        imwrite(img.numpy(), result_path + ".exr")
        convertEXR2ColorMap(result_path + ".exr", result_path + ".png",
                            -scale, scale, 1.0, False)

    # ------------------------------------------------------------------
    # Render modes
    # ------------------------------------------------------------------

    def renderD(self, use_mala=True):
        """Render the derivative image w.r.t. the gin-bound ``Transform``.

        The result is ``self.integrator``'s derivative plus either the boundary
        (MALA) integrator's term (``use_mala=True``) or the indirect + direct
        edge integrators' terms, plus every integrator in ``aug_integrators``.
        It is saved as ``<prefix>backward`` (or ``<prefix>backward_ref`` when
        ``use_mala`` is False) in ``.exr`` and ``.png``.
        """
        psdr_cpu.set_forward(True)
        xform = Transform(gin.REQUIRED)
        sceneAD, d_image = self._prepare_diff_scene(xform)
        d_flat = d_image.reshape(-1)
        img = self.integrator.renderD(sceneAD, self.options, d_flat)

        boundary_integrator = psdr_cpu.BoundaryIntegrator(sceneAD.val)
        ind_integrator = psdr_cpu.IndirectEdgeIntegrator(sceneAD.val)
        dir_integrator = psdr_cpu.DirectEdgeIntegrator(sceneAD.val)
        boundary_integrator.configure_mala(self.mala_options)
        self._preprocess_guiding(sceneAD, boundary_integrator, ind_integrator,
                                 dir_integrator)

        if use_mala:
            img += boundary_integrator.renderD(sceneAD, self.options, d_flat)
        else:
            img += ind_integrator.renderD(sceneAD, self.options, d_flat)
            img += dir_integrator.renderD(sceneAD, self.options, d_flat)

        for integrator in self.aug_integrators:
            integrator.enable_antithetic = self.enable_antithetic
            img += integrator.renderD(sceneAD, self.options, d_flat)

        name = "backward" if use_mala else "backward_ref"
        self._write_derivative_image(
            img, sceneAD, xform, os.path.join(self.out_dir, self.prefix + name))

    def render_converge(self, num_iters=10):
        """Render ``num_iters`` boundary-integrator derivative images with different seeds.

        The first render uses seed 2 and is not saved on its own. Render ``i``
        of the remaining ``num_iters - 1`` uses seed ``i + 5`` and is saved as
        ``val/<prefix>backward_<i>``. The average of all renders is saved as
        ``<prefix>backward_full``. ``self.options.seed`` is left at the last
        seed used.

        Raises:
            ValueError: if ``num_iters`` is smaller than 1.
        """
        if num_iters < 1:
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")
        xform = Transform(gin.REQUIRED)
        sceneAD, d_image = self._prepare_diff_scene(xform)
        d_flat = d_image.reshape(-1)
        # The primary integrator's derivative is rendered into sceneAD, as in
        # renderD, but its image is not part of the averaged result.
        self.integrator.renderD(sceneAD, self.options, d_flat)

        boundary_integrator = psdr_cpu.BoundaryIntegrator(sceneAD.val)
        ind_integrator = psdr_cpu.IndirectEdgeIntegrator(sceneAD.val)
        boundary_integrator.configure_mala(self.mala_options)
        self._preprocess_guiding(sceneAD, boundary_integrator, ind_integrator)

        self.options.seed = 2
        img_full = boundary_integrator.renderD(sceneAD, self.options, d_flat)
        for i in range(num_iters - 1):
            self.options.seed = i + 5
            img = boundary_integrator.renderD(sceneAD, self.options, d_flat)
            img_full += img
            self._write_derivative_image(
                img, sceneAD, xform,
                os.path.join(self.out_dir, "val", self.prefix + "backward_" + str(i)))

        self._write_derivative_image(
            img_full / num_iters, sceneAD, xform,
            os.path.join(self.out_dir, self.prefix + "backward_full"))

    def render_fd(self):
        """Finite-difference reference image, saved as ``<prefix>fd`` (.exr and .png).

        The result is computed as ``(render(perturbed) - render(current)) / delta``
        with channels 1-2 zeroed. The perturbation applies
        ``xform.transformation.transform(v, delta)`` to the selected vertex, or
        to all vertices when ``vertex_id < 0``. The perturbed vertices are
        written back to ``self.scene`` and are not restored.
        """
        delta = self.delta
        image1 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)

        xform = Transform(gin.REQUIRED)
        self._check_shape_id(xform)
        shape = self.scene.shapes[xform.shape_id]
        vertices = np.array(shape.vertices_world)
        if xform.vertex_id >= 0:
            vertices[xform.vertex_id] = xform.transformation.transform(
                vertices[xform.vertex_id], delta)
        else:
            for i, vertex in enumerate(vertices):
                vertices[i] = xform.transformation.transform(vertex, delta)
        self.scene.shapes[xform.shape_id].vertices_world = vertices
        self.scene.configure()
        image2 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)
        fd = (image2 - image1) / delta
        fd[:, :, 1:3] = 0
        fname = os.path.join(self.out_dir, self.prefix + "fd")
        imwrite(fd, fname + ".exr")
        convertEXR2ColorMap(fname + ".exr", fname + ".png", -scale, scale, 1.0, False)

    def d_render(self):
        """Debug helper: run ``self.integrator.renderD`` and print the summed |d vertices| of shape 0."""
        sceneAD = psdr_cpu.SceneAD(self.scene)
        d_image = np.ones((self.height, self.width, 3))
        d_image[:, :, 1:3] = 0
        self.integrator.renderD(
            sceneAD, self.options, d_image.reshape(-1).astype(np.float32))
        print(torch.tensor(sceneAD.der.shapes[0].vertices).abs().sum())

    def run(self):
        """Dispatch to the render mode named by ``render_type``.

        Raises:
            ValueError: if ``render_type`` is not a known mode.
        """
        modes = {
            "forward": self.renderC,
            "backward": self.renderD,
            "fd": self.render_fd,
            "converge": self.render_converge,
        }
        try:
            render = modes[self.render_type]
        except KeyError:
            raise ValueError(
                f"unknown render_type {self.render_type!r}; "
                f"expected one of {sorted(modes)}") from None
        render()

def write_vol(fname, data, size):
    """Write a single-channel dense float32 grid to a binary VOL file.

    Layout (all little-endian): b"VOL", version byte 3, uint32 encoding 1
    (dense float32), uint32 cell counts x, y, z, uint32 channel count 1, six
    float32 bounding-box values (0, 0, 0, 1, 1, 1), then the voxel values with
    x varying fastest, then y, then z.

    Args:
        fname: output path.
        data: voxel values, either flat (already in x-fastest order) or a 3-D
            array indexed ``[x, y, z]`` (the layout returned by read_vol).
            Any array-like accepted by ``np.asarray`` works.
        size: ``(x, y, z)`` cell counts; numpy integers are accepted.

    Raises:
        ValueError: if ``data`` does not hold exactly ``x * y * z`` values.
            The check runs before the file is opened, so no partial file is
            left behind.
    """
    import struct

    xres, yres, zres = (int(s) for s in size)
    values = np.asarray(data, dtype=np.float32)
    if values.ndim == 3:
        # [x, y, z] -> [z, y, x] so a C-order flatten puts x fastest, mirroring read_vol.
        values = values.transpose(2, 1, 0)
    values = values.reshape(-1)
    n = xres * yres * zres
    if values.size != n:
        raise ValueError(
            f"write_vol: got {values.size} values for a {xres}x{yres}x{zres} grid ({n} expected)")

    with open(fname, "wb") as fout:
        fout.write("VOL".encode("ascii"))
        fout.write(int.to_bytes(3, 1, 'little'))               # format version
        fout.write(struct.pack('<I', 1))                       # encoding: dense float32
        fout.write(struct.pack('<3I', xres, yres, zres))
        fout.write(struct.pack('<I', 1))                       # channel count
        fout.write(struct.pack('<6f', 0, 0, 0, 1, 1, 1))       # bounding box
        # One bulk write instead of struct.pack over a Python list of n floats.
        fout.write(values.astype('<f4').tobytes())
        
def read_vol(fname):
    """Read a dense float32 VOL grid (the format written by write_vol).

    Header (little-endian): b"VOL", version byte (must be 3), uint32 encoding
    (must be 1 = dense float32), uint32 cell counts x, y, z, uint32 channel
    count, six float32 bounding-box values (read but not returned). Voxel
    values follow, with the channel varying fastest, then x, y and z.

    Returns:
        A float64 numpy array indexed ``[x, y, z]`` for single-channel files,
        or ``[x, y, z, c]`` when the file stores more than one channel.

    Raises:
        ValueError: on a bad magic, an unsupported version or encoding, or a
            truncated header or payload.
    """
    import struct

    with open(fname, "rb") as fin:
        header = fin.read(48)
        if len(header) < 4 or header[:3] != "VOL".encode("ascii"):
            raise ValueError("invalid vol file")
        if header[3] != 3:
            raise ValueError("invalid vol version")
        if len(header) < 48:
            raise ValueError("truncated vol header")
        encoding, xres, yres, zres, n_channels = struct.unpack('<I3II', header[4:24])
        # header[24:48] is the float32 bounding box; it is not returned.
        if encoding != 1:
            raise ValueError(
                f"unsupported vol encoding {encoding} (only 1 = dense float32)")
        count = xres * yres * zres * n_channels
        payload = fin.read(4 * count)

    if len(payload) != 4 * count:
        raise ValueError("truncated vol data")
    data = np.frombuffer(payload, dtype='<f4').astype(np.float64)
    if n_channels == 1:
        # File order is z-major/x-fastest; transpose to [x, y, z].
        return data.reshape(zres, yres, xres).transpose(2, 1, 0)
    return data.reshape(zres, yres, xres, n_channels).transpose(2, 1, 0, 3)

if __name__ == "__main__":
    # Script entry point: chdir to this file's directory, parse the gin config
    # named on the command line (default ./napoleon.conf), build a TestRunner
    # and run the render mode selected by its render_type.
    # Relative paths used below (default config, './output/...', 'grid.txt')
    # therefore resolve next to this script.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    default_config = './napoleon.conf'
    parser = argparse.ArgumentParser(
        description='Script for generating validation results')
    parser.add_argument('config_file', metavar='config_file',
                        type=str, nargs='?', default=default_config, help='config file')
    # parse_known_args: unrecognized command-line arguments end up in `unknown` and are ignored.
    args, unknown = parser.parse_known_args()
    # Dependency injection: Arguments are injected into the function from the gin config file.
    gin.add_config_file_search_path(os.getcwd())
    gin.parse_config_file(args.config_file, skip_unknown=True)
    # TestRunner is @gin.configurable, so its constructor arguments come from the bindings parsed above.
    test_runner = TestRunner()
    # yep.start("val.prof")
    
    # direct_integrator = psdr_cpu.DirectEdgeMLT(test_runner.scene)
    # out_arr = direct_integrator.diff_bsdf_test(0.05, -0.1, 0.1)
    # bsdf = out_arr[6:]
    # print(out_arr)
    # print(bsdf, np.linalg.norm(bsdf))
    
    test_runner.run()
    # test_runner.render_converge()
    # NOTE(review): these two assignments happen after run() and right before
    # exit(), so they do not affect any render performed by this script.
    test_runner.options.spp = 128
    test_runner.options.sppse0 = 0
    # test_runner.renderC()
    print("done")
    # NOTE(review): the script ends here. Everything below this exit() is
    # unreachable scratch code (grid solve, MALA sample validation, plots).
    exit()
    
    xform = Transform(gin.REQUIRED)
    
    # Restrict edges to the target shape and apply the guiding sort config
    # (same setup as TestRunner.renderD).
    for i in range(len(test_runner.scene.shapes)):
        test_runner.scene.shapes[i].enable_edge = False
    test_runner.scene.shapes[xform.shape_id].enable_edge = True
    
    test_runner.scene.shapes[xform.shape_id].sort_config = test_runner.guiding_options.sort_config
    test_runner.scene.shapes[xform.shape_id].enable_draw = True
    test_runner.scene.shapes[xform.shape_id].configure()
    test_runner.scene.configure()
    
    indirect_integrator = psdr_cpu.IndirectEdgeMLT(test_runner.scene)
    # direct_integrator_ref = psdr_cpu.DirectEdgeIntegrator(test_runner.scene)
    
    mala_options = MALAOptions(gin.REQUIRED)
    indirect_integrator.load_MALA_config(mala_options)
    
    # Resolution (cells per axis) and bounds of the grid passed to solve_Grid.
    grid = np.array([200, 200, 200])
    # min_b = np.array([0.0, 0.67, 0.2])
    # max_b = np.array([1.0, 0.82, 0.7])
    min_b = np.array([0.0, 0.0, 0.0])
    max_b = np.array([1.0, 1.0, 1.0])
    
    # NOTE(review): the shape of `data` is not visible here (it is printed
    # below); write_vol needs grid[0]*grid[1]*grid[2] values in its expected layout.
    data = indirect_integrator.solve_Grid(test_runner.scene, grid, min_b, max_b)
    np.save('grid_full_rough.npy', data)
    print(data.max())
    print(data.shape)
    print(np.max(data))
    write_vol("tri_full_rough.vol", data, grid)
    # pc = indirect_integrator.solve_MALA(test_runner.scene, grid, min_b, max_b)
    # print(pc.shape)
    # np.save('pointcloud.npy', pc)
    # NOTE(review): second exit(); the code below stays unreachable even if the one above is removed.
    exit()
    
    # 2-D slice through the 3-D sample space: coordinate `fixed_axis` is pinned
    # to `fixed_axis_value`; f/d_f take the two remaining coordinates.
    fixed_axis = 0
    fixed_axis_value = 0.25
    
    shape = test_runner.scene.shapes[xform.shape_id]
    shape.requires_grad = True
    if xform.vertex_id >= 0:
        shape.vertex_idx = xform.vertex_id
    if type(xform.transformation) is Translation:
        # velocity = np.zeros_like(shape.vertices)
        # velocity[:, :] = xform.transformation.translation
        # shape.setVelocities(velocity)
        shape.setTranslation(xform.transformation.translation)
    elif type(xform.transformation) is Rotation:
        shape.setRotation(xform.transformation.rotation)
    else:
        assert(False)
    
    sceneAD = psdr_cpu.SceneAD(test_runner.scene)
    scene1 = sceneAD.val
    
    # Histogram helper: increments the cell of `grid` hit by a 2-D sample in [0, 1)^2
    # (sample[1] selects the row, sample[0] the column). Not called below.
    def alloc_grid(grid, sample):
        grid_x = int(sample[1] * grid.shape[0])
        grid_y = int(sample[0] * grid.shape[1])
        grid[grid_x, grid_y] += 1
    
    # Value of DirectEdgeMLT.func_1b at the 3-D point built from the two free coordinates.
    def f(input): # input: (2)
        vec = np.zeros(3)
        vec[fixed_axis] = fixed_axis_value
        vec[(fixed_axis + 1) % 3] = input[0]
        vec[(fixed_axis + 2) % 3] = input[1]
        direct_integrator = psdr_cpu.DirectEdgeMLT(test_runner.scene)
        ret = direct_integrator.func_1b(test_runner.scene, vec[0], vec[1], vec[2])
        return torch.tensor(ret, dtype=torch.float32)

    # Gradient of func_1b (DirectEdgeMLT.d_func_1b) restricted to the two free coordinates.
    def d_f(input): # input: a single point with 2 free coordinates (input[0], input[1])
        vec = np.zeros(3)
        vec[fixed_axis] = fixed_axis_value
        vec[(fixed_axis + 1) % 3] = input[0]
        vec[(fixed_axis + 2) % 3] = input[1]
        # sceneAD = psdr_cpu.SceneAD(test_runner.scene)
        direct_integrator = psdr_cpu.DirectEdgeMLT(scene1)
        ret = direct_integrator.d_func_1b(scene1, vec[0], vec[1], vec[2])
        if fixed_axis == 0:
            return torch.tensor([ret[1], ret[2]], dtype=torch.float32)
        elif fixed_axis == 1:
            return torch.tensor([ret[0], ret[2]], dtype=torch.float32)
        else:
            return torch.tensor([ret[0], ret[1]], dtype=torch.float32)

    cov = torch.tensor([[0.1, 0.095], [0.095, 0.1]], dtype=torch.float32)
    resolution = [50, 50]
    x = np.arange(0.0, 1.0, 1.0 / resolution[0])
    y = np.arange(0.0, 1.0, 1.0 / resolution[1])
    
    # Number of MALA steps passed to solve_MALA; also used to normalize its histogram below.
    T = 200000
    xform = Transform(gin.REQUIRED)
    assert(xform.shape_id >= 0)
    shape = test_runner.scene.shapes[xform.shape_id]
    shape.requires_grad = True
    sceneAD = psdr_cpu.SceneAD(test_runner.scene)
    
    # d_image = np.ones((test_runner.height, test_runner.width, 3))
    # d_image[:, :, 1:3] = 0
    # img = direct_integrator.renderD(
    #     sceneAD, test_runner.options, d_image.reshape(-1).astype(np.float32))
    # print(img.sum())

    # img = torch.tensor(img)
    # img[img.isnan()] = 0
    # img = img.reshape((test_runner.height, test_runner.width, 3))
    # print("img grad:", img.clone().detach().sum())
    # print("grad: ", torch.tensor(
    #     np.array(sceneAD.der.shapes[xform.shape_id].vertices))[0][2])
    # imwrite(img.numpy(), os.path.join(
    #     test_runner.out_dir,  test_runner.prefix + "backward.exr"))
    # exit()
    
    
    # NOTE(review): `direct_integrator` is not defined at this scope (it only
    # exists as a local inside f/d_f), so these two lines would raise NameError if reached.
    output_array = direct_integrator.get_sample_slice(test_runner.scene, fixed_axis, fixed_axis_value, resolution[0], resolution[1])
    grid = direct_integrator.solve_MALA(test_runner.scene, fixed_axis, fixed_axis_value, resolution[0], resolution[1], 1e-2, T)
    
    # Reads 'grid.txt': a "<w> <h>" header, then w*h lines of "g0 g1 G0 G1 count".
    # g/G are divided by count, or left unscaled when count < 0.5.
    # NOTE(review): `as f` rebinds the name `f`, shadowing the function f defined above.
    with open('grid.txt', 'r') as f:
        size = f.readline()
        size = size.strip()
        size = size.split(' ')
        size = [int(size[0]), int(size[1])]
        x, y = np.meshgrid(np.arange(0.0, 1.0, 1.0 / size[0]), np.arange(1.0, 0.0, -1.0 / size[1]))
        grid_g = np.zeros([size[0], size[1], 2], dtype=np.float32)
        grid_G = np.zeros([size[0], size[1], 2], dtype=np.float32)
        density = np.zeros([size[0], size[1]], dtype=np.float32)
        for i in range(size[0]):
            for j in range(size[1]):
                line = f.readline()
                line = line.strip()
                line = line.split(' ')
                density[i, j] = float(line[4])
                s = density[i, j]
                if (s < 0.5):
                    s = 1.0
                grid_g[i, j, 0] = float(line[0]) / s
                grid_g[i, j, 1] = float(line[1]) / s
                grid_G[i, j, 0] = float(line[2]) / s
                grid_G[i, j, 1] = float(line[3]) / s
    print(grid_g.shape, grid_G.shape, density.shape, x.shape, y.shape)
    
    print(grid)
    grid = grid.reshape(resolution[0], resolution[1]).transpose(1, 0)
    output_array = output_array.reshape(resolution[0], resolution[1])
    print(grid.shape, output_array.shape)
    density = density.transpose(1, 0)
    
    # Normalize both to sum to 1 before comparing the MALA histogram with the reference slice.
    grid /= T
    output_array = output_array / output_array.sum()
    print(output_array.sum(), grid.sum())
    
    # NOTE(review): despite the name this is sqrt(sum(sq. error)) with no division by the cell count.
    rmse = np.sqrt(np.sum((output_array - grid)**2))
    print('rmse: ', rmse)
    
    fig = plt.figure(figsize=(15, 5))
    ax = fig.add_subplot(131)
    # NOTE(review): at module scope this rebinds the global `scale` that the TestRunner methods pass to convertEXR2ColorMap.
    scale = output_array.max() * 0.8
    ax.imshow(grid, cmap='viridis', vmin = 0.0, vmax = scale, interpolation='nearest')
        
    ax = fig.add_subplot(132)
    # contour for f
    # print(output_array.max(), output_array.min())
    ax.imshow(output_array, cmap='viridis', vmin = 0.0, vmax = scale, interpolation='nearest')
        
    ax = fig.add_subplot(133)
    # difference: reference slice minus MALA histogram
    ax.imshow(output_array - grid, cmap='viridis', vmin = -scale, vmax = scale, interpolation='nearest')
    plt.savefig('./output/val_mala.png')
    
    fig = plt.figure(figsize=(15, 5))
    ax = fig.add_subplot(131)
    ax.quiver(x, y, grid_g[:, :, 0], grid_g[:, :, 1], scale = 20)
    
    ax = fig.add_subplot(132)
    ax.quiver(x, y, grid_G[:, :, 0], grid_G[:, :, 1], scale = 400)
    
    # NOTE(review): `s` is unused; the imshow below divides by the literal 20000 instead.
    s = min(T, 20000)
    ax = fig.add_subplot(133)
    ax.imshow(density / 20000, cmap='viridis', vmin = 0.0, vmax = scale, interpolation='nearest')    
    plt.savefig('./output/val_mala2.png')
    
    # NOTE(review): third exit(); the remaining lines are unreachable.
    exit()
    r,t = np.meshgrid(x, y)
    input_array = torch.tensor(np.c_[r.reshape(-1), t.reshape(-1)], dtype=torch.float32)
    # output_array = f(input_array, cov).reshape(resolution[0], resolution[1])
    output_array = output_array.reshape(resolution[0], resolution[1])
    #sample random input in [-1, 1] x [-1, 1]
    
    grid = np.zeros(resolution, dtype=np.float32)
    
    # NOTE(review): `f` is a closed file object at this point (rebound by the
    # `with open('grid.txt') as f` above), so these f(...) calls would fail.
    input_sample = torch.tensor([0.5, 0.5])
    print(input_sample)
    input_sample.requires_grad = True
    output_sample_print = f(input_sample)
    print(output_sample_print)
    input_sample_grad = d_f(input_sample)
    print(input_sample_grad)
    output_sample_print = f(input_sample)
    print(output_sample_print)
    input_sample_grad = d_f(input_sample)
    print(input_sample_grad)
    
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # ax.plot_surface(r, t, output_array.detach(), cmap='viridis', edgecolor='none')
    ax.imshow(output_array, cmap='hot', interpolation='nearest')
    plt.savefig('./output/test.png')