import argparse
from email.policy import default
from typing import List, Optional
from cv2 import transform
import gin
import psdr_cpu
from pypsdr.validate import *
from pypsdr.utils.io import *
import dataclasses
import os
import numpy as np
import torch
# import yep

# Environment switches for native libraries used by this script:
# - KMP_DUPLICATE_LIB_OK lets an OpenMP runtime keep going when it finds
#   another copy already loaded (presumably torch vs. psdr_cpu -- TODO confirm).
# - OPENCV_IO_ENABLE_OPENEXR turns on OpenCV's EXR codec, presumably what
#   imwrite uses for the .exr outputs.
# NOTE(review): cv2 is imported above this point; this relies on OpenCV
# reading the flag lazily at its first EXR I/O -- confirm.
os.environ.update({
    'KMP_DUPLICATE_LIB_OK': 'True',
    "OPENCV_IO_ENABLE_OPENEXR": "1",
})

# Dependency injection: @gin.configurable (below) lets the gin config file supply the constructor arguments that @dataclass generates.


@gin.configurable
# @dataclass generates __init__ from the annotated fields; @gin.configurable
# (outermost) then lets the gin config file supply those constructor arguments.
@dataclasses.dataclass
class TestRunner:
    """Gin-configured driver for psdr_cpu validation renders.

    ``run()`` dispatches on ``render_type``:

    * ``"forward"``  -- primal image from ``integrator.renderC``, written to
      ``<prefix>forward.exr``;
    * ``"backward"`` -- derivative image w.r.t. the gin-configured
      ``Transform``, written to ``<prefix>backward.exr``;
    * ``"fd"``       -- finite-difference reference with step ``delta``,
      written to ``<prefix>fd.exr``.

    Files go to ``out_dir``. ``<prefix>`` is
    ``<IntegratorClassName>_<spp>_`` followed by ``<suffix>_`` when a suffix
    is given.
    """

    scene: psdr_cpu.Scene
    options: psdr_cpu.RenderOptions
    integrator: psdr_cpu.Integrator
    # Extra integrators whose renderD() output is added to the image in renderD().
    aug_integrators: List = dataclasses.field(default_factory=list)
    # Constructed per runner (when the runner is created) instead of a single
    # GuidingOptions() built at import time and shared by every instance;
    # Python >= 3.11 dataclasses also reject unhashable instance defaults.
    guiding_options: GuidingOptions = dataclasses.field(
        default_factory=GuidingOptions)
    out_dir: str = "./"
    render_type: str = "forward"
    suffix: str = ""
    # Step size of the finite difference in render_fd().
    delta: float = 0.1
    # If not None, copied onto each aug integrator's `enable_antithetic` in
    # renderD(); None leaves the integrators' own setting untouched.
    enable_antithetic: Optional[bool] = None

    def __post_init__(self):
        """Create ``out_dir``, cache the camera size and build the file prefix."""
        os.makedirs(self.out_dir, exist_ok=True)
        self.width = self.scene.camera.width
        self.height = self.scene.camera.height
        self.prefix = type(self.integrator).__name__ + \
            "_" + str(self.options.spp) + "_"
        if self.suffix != "":
            self.prefix += self.suffix + "_"
        psdr_cpu.set_verbose(True)

    @staticmethod
    def _load_transform():
        """Build the gin-configured ``Transform`` and validate its shape id.

        Raises:
            ValueError: if ``shape_id`` is negative; indexing ``scene.shapes``
                with it would otherwise silently pick a shape from the end.
        """
        xform = Transform(gin.REQUIRED)
        if xform.shape_id < 0:
            raise ValueError(
                f"Transform.shape_id must be >= 0, got {xform.shape_id}")
        return xform

    def renderC(self):
        """Render the primal image and write ``<prefix>forward.exr``."""
        image = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)
        imwrite(image, os.path.join(self.out_dir, self.prefix + "forward.exr"))

    def _preprocess_guiding(self, boundary_integrator, scene):
        """Recompute edges and run guided-sampling preprocessing.

        ``guiding_options.guide_option`` selects "direct", "indirect" or
        "both"; ``guide_type == "grid"`` uses the grid preprocessors, any other
        value the aq ones.
        """
        guiding = self.guiding_options
        use_grid = guiding.guide_type == "grid"
        max_bounces = self.options.max_bounces

        if guiding.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_direct(
                    scene, guiding.grid_config_direct, max_bounces)
            else:
                boundary_integrator.preprocess_aq_direct(
                    scene, guiding.aq_config_direct, max_bounces)

        if guiding.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_indirect(
                    scene, guiding.grid_config_indirect, max_bounces)
            else:
                boundary_integrator.preprocess_aq_indirect(
                    scene, guiding.aq_config_indirect, max_bounces)

    def renderD(self):
        """Render the derivative image and write ``<prefix>backward.exr``.

        The shape chosen by the gin-configured ``Transform`` is marked as
        requiring gradients and given the configured translation or rotation
        (restricted to ``vertex_id`` when it is >= 0). The output is the sum of
        ``integrator.forwardRenderD``, a ``BoundaryIntegrator``'s
        ``forwardRenderD`` (with guided edge sampling when
        ``guiding_options.guide_type`` is non-empty) and every aug
        integrator's ``renderD``. NaNs are zeroed before writing, and two
        diagnostic values are printed.

        Raises:
            ValueError: if the transform's ``shape_id`` is negative.
            TypeError: if the transformation is neither a ``Translation`` nor
                a ``Rotation``.
        """
        psdr_cpu.set_forward(True)
        # dependency injection: shape, vertex and motion come from gin
        xform = self._load_transform()
        shape = self.scene.shapes[xform.shape_id]
        guided = self.guiding_options.guide_type != ""

        if guided:
            shape.sort_config = self.guiding_options.sort_config
            shape.enable_draw = True
            shape.configure()

        # Per-pixel weights for the aug integrators' renderD: ones in
        # channel 0, zeros in channels 1-2.
        d_image = np.ones((self.height, self.width, 3))
        d_image[:, :, 1:3] = 0

        shape.requires_grad = True
        if xform.vertex_id >= 0:
            shape.vertex_idx = xform.vertex_id
        motion = xform.transformation
        if type(motion) is Translation:
            shape.setTranslation(motion.translation)
        elif type(motion) is Rotation:
            shape.setRotation(motion.rotation)
        else:
            raise TypeError(
                f"unsupported transformation type {type(motion).__name__}")

        sceneAD = psdr_cpu.SceneAD(self.scene)
        # Alternative kept for reference:
        # self.integrator.renderD(sceneAD, self.options, d_image.reshape(-1))
        img = self.integrator.forwardRenderD(sceneAD, self.options)

        boundary_integrator = psdr_cpu.BoundaryIntegrator(sceneAD.val)
        if guided:
            self._preprocess_guiding(boundary_integrator, sceneAD.val)
        img += boundary_integrator.forwardRenderD(sceneAD, self.options)

        for integrator in self.aug_integrators:
            if self.enable_antithetic is not None:
                integrator.enable_antithetic = self.enable_antithetic
            img += integrator.renderD(
                sceneAD, self.options, d_image.reshape(-1))

        img = torch.tensor(img)
        img[img.isnan()] = 0
        img = img.reshape((self.height, self.width, 3))
        print("img grad:", img.clone().detach().sum())
        # Element [0][2] of the vertex derivatives (presumably z of vertex 0).
        print("grad: ", torch.tensor(
            np.array(sceneAD.der.shapes[xform.shape_id].vertices))[0][2])
        imwrite(img.numpy(), os.path.join(
            self.out_dir, self.prefix + "backward.exr"))

    def render_fd(self):
        """Write a finite-difference reference image ``<prefix>fd.exr``.

        Renders the scene, moves the configured shape's vertices (only
        ``vertex_id`` when it is >= 0) with
        ``transformation.transform(v, delta)``, re-renders and writes
        ``(image2 - image1) / delta`` with channels 1-2 zeroed.
        The scene is left in its displaced state.

        Raises:
            ValueError: if the transform's ``shape_id`` is negative.
        """
        delta = self.delta
        image1 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)

        xform = self._load_transform()
        shape = self.scene.shapes[xform.shape_id]
        vertices = np.array(shape.vertices)
        if xform.vertex_id >= 0:
            vertices[xform.vertex_id] = xform.transformation.transform(
                vertices[xform.vertex_id], delta)
        else:
            for i, vertex in enumerate(vertices):
                vertices[i] = xform.transformation.transform(vertex, delta)
        self.scene.shapes[xform.shape_id].vertices = vertices
        self.scene.configure()

        image2 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)
        fd = (image2 - image1) / delta
        fd[:, :, 1:] = 0
        imwrite(fd, os.path.join(self.out_dir, self.prefix + "fd.exr"))

    def d_render(self):
        """Debug helper (not reachable from ``run()``).

        Calls ``integrator.renderD`` with a float32 weight image that is one
        in channel 0 and zero elsewhere, then prints the sum of absolute
        vertex derivatives of shape 0.
        """
        sceneAD = psdr_cpu.SceneAD(self.scene)
        d_image = np.ones((self.height, self.width, 3))
        d_image[:, :, 1:3] = 0
        self.integrator.renderD(
            sceneAD, self.options, d_image.reshape(-1).astype(np.float32))
        print(torch.tensor(sceneAD.der.shapes[0].vertices).abs().sum())

    def run(self):
        """Run the renderer selected by ``render_type``.

        Raises:
            ValueError: if ``render_type`` is not "forward", "backward" or "fd".
        """
        renderers = {
            "forward": self.renderC,
            "backward": self.renderD,
            "fd": self.render_fd,
        }
        try:
            render = renderers[self.render_type]
        except KeyError:
            raise ValueError(
                f"unknown render_type {self.render_type!r}; "
                f"expected one of {sorted(renderers)}") from None
        render()

def write_vol(fname, data, size):
    """Write a single-channel float32 grid as a binary ``.vol`` file.

    Layout (all little-endian): ``b"VOL"``, version byte 3, encoding id 1
    (float32), the three resolutions, channel count 1, the bounding box
    ``(0, 0, 0)``-``(1, 1, 1)`` as six floats, then ``xres * yres * zres``
    float32 samples.

    Args:
        fname: output file path.
        data: sample values. A 3-D array must have shape ``size`` and be
            indexed ``data[x, y, z]``; it is reordered so that x varies
            fastest. Any other input is flattened in C order and written as-is.
        size: ``(xres, yres, zres)`` grid resolution.

    Raises:
        ValueError: if ``data`` does not hold exactly ``xres * yres * zres``
            values, or a 3-D ``data`` does not have shape ``size``. The check
            happens before the file is opened, so no truncated file is left.
    """
    import struct

    xres, yres, zres = (int(s) for s in size)
    count = xres * yres * zres

    values = np.asarray(data, dtype=np.float32)
    if values.ndim == 3:
        if values.shape != (xres, yres, zres):
            raise ValueError(
                f"write_vol: 3-D data has shape {values.shape}, "
                f"expected {(xres, yres, zres)}")
        # [x, y, z] -> [z, y, x]: after C-order flattening x varies fastest.
        values = values.transpose(2, 1, 0)
    values = values.reshape(-1)
    if values.size != count:
        raise ValueError(
            f"write_vol: expected {count} values for size "
            f"{(xres, yres, zres)}, got {values.size}")

    with open(fname, "wb") as fout:
        fout.write(b"VOL")
        fout.write(struct.pack("<B", 3))                  # format version
        fout.write(struct.pack("<I", 1))                  # encoding: float32
        fout.write(struct.pack("<3I", xres, yres, zres))  # resolution
        fout.write(struct.pack("<I", 1))                  # channel count
        fout.write(struct.pack("<6f", 0, 0, 0, 1, 1, 1))  # bounding box
        # One bulk little-endian write instead of packing each float in Python.
        fout.write(np.ascontiguousarray(values, dtype="<f4").tobytes())

if __name__ == "__main__":
    default_config = './two_triangles.conf'
    parser = argparse.ArgumentParser(
        description='Script for generating validation results')
    parser.add_argument('config_file', metavar='config_file',
                        type=str, nargs='?', default=default_config, help='config file')
    args, unknown = parser.parse_known_args()
    # Dependency injection: Arguments are injected into the function from the gin config file.
    gin.add_config_file_search_path(os.getcwd())
    gin.parse_config_file(args.config_file, skip_unknown=True)
    test_runner = TestRunner()
    # yep.start("val.prof")

    test_runner.run()
    
    # xform = Transform(gin.REQUIRED)
    
    # for i in range(len(test_runner.scene.shapes)):
    #     test_runner.scene.shapes[i].enable_edge = False
    # test_runner.scene.shapes[xform.shape_id].enable_edge = True
    # test_runner.scene.configure()
    
    # test_runner.scene.shapes[xform.shape_id].sort_config = test_runner.guiding_options.sort_config
    # test_runner.scene.shapes[xform.shape_id].enable_draw = True
    # test_runner.scene.shapes[xform.shape_id].configure()
    
    # direct_integrator = psdr_cpu.DirectEdgeIntegrator(test_runner.scene)
    
    # grid_x = 2000
    # grid_y = 25
    # grid_z = 25
    
    # # data = np.ones((grid_x, grid_y, grid_z))
    # # for i in range(grid_x):
    # #     for j in range(grid_y):
    # #         print('progress: ', i*grid_y*grid_z+j*grid_z, '/', grid_x*grid_y*grid_z)
    # #         for k in range(grid_z):
    # #             data[i, j, k] = i*j*k/grid_x/grid_y/grid_z
    # data = direct_integrator.get_sample_vol(test_runner.scene, grid_x, grid_y, grid_z)
    # print(data.shape)
    # print(np.max(data))
    # write_vol("tri_02.vol", data, (grid_x, grid_y, grid_z))
    # yep.stop()
