import argparse
from email.policy import default
import time
from typing import List, Optional
from cv2 import transform
import gin
import psdr_cpu
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.utils.exr2png import convertEXR2ColorMap
import dataclasses
import os
import numpy as np
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Fix the seed of torch's global RNG so torch-side randomness in this script repeats run to run.
torch.random.manual_seed(10)

# KMP_DUPLICATE_LIB_OK: let the Intel OpenMP runtime keep going (instead of aborting) when
#   more than one copy of it ends up loaded in the process.
# OPENCV_IO_ENABLE_OPENEXR: recent OpenCV builds keep the EXR codec disabled unless this is set.
# NOTE(review): cv2 is imported above, before this variable is set; this relies on OpenCV
#   reading it lazily on first EXR access -- confirm, or move these assignments above the imports.
os.environ.update({
    "KMP_DUPLICATE_LIB_OK": "True",
    "OPENCV_IO_ENABLE_OPENEXR": "1",
})

# gin.configurable: the gin config file injects the constructor arguments.
# dataclasses.dataclass: generates that constructor from the annotated fields below.
@gin.configurable
@dataclasses.dataclass
class TestRunner:
    """Run forward, derivative, convergence and finite-difference renders of a psdr_cpu scene.

    Fields are meant to be bound from a gin config file. Output images are written
    to ``out_dir`` with the filename prefix ``<IntegratorClass>_<spp>_[<suffix>_]``.
    """

    scene: psdr_cpu.Scene
    options: psdr_cpu.RenderOptions
    mala_options: psdr_cpu.MALAOptions
    integrator: psdr_cpu.Integrator
    # Extra integrators whose derivative images are added on top in renderD.
    aug_integrators: List = dataclasses.field(default_factory=list)
    # Built per instance rather than once at import time: gin bindings for
    # GuidingOptions are parsed after this module is imported, and a shared
    # default instance would ignore them (and is rejected by Python 3.11+
    # dataclasses if GuidingOptions is unhashable).
    guiding_options: GuidingOptions = dataclasses.field(default_factory=GuidingOptions)
    out_dir: str = "./"
    # One of "forward", "backward", "fd" or "converge" (see run()).
    render_type: str = "forward"
    suffix: str = ""
    # Step size used by render_fd.
    delta: float = 0.001
    # When not None, assigned to every aug integrator's enable_antithetic before it renders.
    enable_antithetic: Optional[bool] = None
    # renderD: True -> BoundaryIntegrator; False -> IndirectEdge + DirectEdge integrators.
    use_mala: bool = True
    # Environment map loaded in __post_init__; an empty string skips loading.
    envmap: str = "envmap2.exr"

    def __post_init__(self):
        """Create output folders, cache image size and prefix, and load the environment map."""
        # makedirs creates out_dir as a parent of out_dir/val.
        os.makedirs(os.path.join(self.out_dir, "val"), exist_ok=True)
        self.width = self.scene.camera.width
        self.height = self.scene.camera.height
        self.prefix = f"{type(self.integrator).__name__}_{self.options.spp}_"
        if self.suffix:
            self.prefix += f"{self.suffix}_"
        psdr_cpu.set_verbose(True)
        if self.envmap:
            self.scene.load_envmap(self.envmap)

    # ------------------------------------------------------------------ helpers

    def _load_transform(self):
        """Build the gin-configured Transform and reject a negative shape index.

        Raises:
            ValueError: if ``xform.shape_id`` is negative.
        """
        xform = Transform(gin.REQUIRED)
        if xform.shape_id < 0:
            raise ValueError(f"Transform.shape_id must be >= 0, got {xform.shape_id}")
        return xform

    def _target_image(self):
        """Return the (height, width, 3) weight image passed to renderD: 1 in channel 0, 0 elsewhere."""
        d_image = np.ones((self.height, self.width, 3))
        d_image[:, :, 1:3] = 0
        return d_image

    def _configure_edges(self, xform):
        """Enable edge sampling only on the transformed shape and reconfigure the scene."""
        shape_id = xform.shape_id
        guide = self.guiding_options
        if guide.guide_type != "":
            self.scene.shapes[shape_id].sort_config = guide.sort_config
            self.scene.shapes[shape_id].enable_draw = True
            self.scene.shapes[shape_id].configure()
        # Index scene.shapes on every access (as before) rather than iterating, so
        # the assignments go through the binding exactly as they always have.
        for i in range(len(self.scene.shapes)):
            self.scene.shapes[i].enable_edge = False
        self.scene.shapes[shape_id].enable_edge = True
        self.scene.configure()
        # NOTE(review): repeated unconditionally after scene.configure(); kept because the
        # guided branch above may be needed before configure() -- confirm if redundant.
        self.scene.shapes[shape_id].sort_config = guide.sort_config
        self.scene.shapes[shape_id].enable_draw = True
        self.scene.shapes[shape_id].configure()

    def _attach_transform(self, xform):
        """Mark the target shape as differentiable and attach xform's translation or rotation.

        Raises:
            TypeError: if ``xform.transformation`` is neither a Translation nor a Rotation.
        """
        shape = self.scene.shapes[xform.shape_id]
        shape.requires_grad = True
        if xform.vertex_id >= 0:
            shape.vertex_idx = xform.vertex_id
        transformation = xform.transformation
        if isinstance(transformation, Translation):
            shape.setTranslation(transformation.translation)
        elif isinstance(transformation, Rotation):
            shape.setRotation(transformation.rotation)
        else:
            raise TypeError(
                f"unsupported transformation type {type(transformation).__name__}")

    def _preprocess_guiding(self, scene_ad, boundary_integrator, ind_integrator, d_integrator=None):
        """Run the guiding preprocessing selected by guiding_options (no-op if guide_type is "").

        ``d_integrator`` is optional: when None, only the boundary and indirect
        integrators are preprocessed.
        """
        guide = self.guiding_options
        if guide.guide_type == "":
            return
        max_bounces = self.options.max_bounces
        use_grid = guide.guide_type == "grid"
        if guide.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene_ad.val)
            if use_grid:
                if d_integrator is not None:
                    d_integrator.preprocess_grid(
                        scene_ad.val, guide.grid_config_direct, max_bounces)
                boundary_integrator.preprocess_grid_direct(
                    scene_ad.val, guide.grid_config_direct, max_bounces)
            else:
                if d_integrator is not None:
                    d_integrator.preprocess_aq(
                        scene_ad.val, guide.aq_config_direct, max_bounces)
                boundary_integrator.preprocess_aq_direct(
                    scene_ad.val, guide.aq_config_direct, max_bounces)
        if guide.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene_ad.val)
            if use_grid:
                ind_integrator.preprocess_grid(
                    scene_ad.val, guide.grid_config_indirect, max_bounces)
                boundary_integrator.preprocess_grid_indirect(
                    scene_ad.val, guide.grid_config_indirect, max_bounces)
            else:
                ind_integrator.preprocess_aq(
                    scene_ad.val, guide.aq_config_indirect, max_bounces)
                boundary_integrator.preprocess_aq_indirect(
                    scene_ad.val, guide.aq_config_indirect, max_bounces)

    def _to_image(self, img):
        """Copy a flat render result into a (height, width, 3) torch tensor with NaNs set to 0."""
        img = torch.tensor(img)
        img[img.isnan()] = 0
        return img.reshape((self.height, self.width, 3))

    def _report_and_save(self, img, scene_ad, shape_id, result_path, scale):
        """Print image/vertex gradient summaries, write ``result_path``.exr and a colour-mapped .png."""
        print("img grad:", img.sum())
        print("grad: ", torch.tensor(
            np.array(scene_ad.der.shapes[shape_id].vertices))[0][2])
        imwrite(img.numpy(), result_path + ".exr")
        convertEXR2ColorMap(result_path + ".exr", result_path + ".png",
                            -scale, scale, 1.0, False)

    # ------------------------------------------------------------------ renders

    def renderC(self):
        """Render the scene forward and write ``<prefix>forward.exr`` to out_dir."""
        image = self.integrator.renderC(self.scene, self.options).reshape(
            self.height, self.width, 3)
        imwrite(image, os.path.join(self.out_dir, self.prefix + "forward.exr"))

    def renderD(self):
        """Render the derivative image w.r.t. the gin-configured Transform.

        Sums the main integrator's renderD with either the BoundaryIntegrator
        (``use_mala``) or the indirect + direct edge integrators, plus every aug
        integrator. Writes ``<prefix>backward`` (or ``<prefix>backward_ref`` when
        ``use_mala`` is False) as .exr and .png.
        """
        xform = self._load_transform()
        self._configure_edges(xform)
        d_flat = self._target_image().reshape(-1)
        self._attach_transform(xform)

        scene_ad = psdr_cpu.SceneAD(self.scene)
        img = self.integrator.renderD(scene_ad, self.options, d_flat)

        boundary_integrator = psdr_cpu.BoundaryIntegrator(scene_ad.val)
        ind_integrator = psdr_cpu.IndirectEdgeIntegrator(scene_ad.val)
        d_integrator = psdr_cpu.DirectEdgeIntegrator(scene_ad.val)
        boundary_integrator.configure_mala(self.mala_options)
        self._preprocess_guiding(scene_ad, boundary_integrator, ind_integrator, d_integrator)

        if self.use_mala:
            img += boundary_integrator.renderD(scene_ad, self.options, d_flat)
        else:
            img += ind_integrator.renderD(scene_ad, self.options, d_flat)
            img += d_integrator.renderD(scene_ad, self.options, d_flat)

        for integrator in self.aug_integrators:
            if self.enable_antithetic is not None:
                integrator.enable_antithetic = self.enable_antithetic
            img += integrator.renderD(scene_ad, self.options, d_flat)

        scale = 0.2  # colour-map range is [-scale, scale]
        name = "backward" if self.use_mala else "backward_ref"
        result_path = os.path.join(self.out_dir, self.prefix + name)
        self._report_and_save(self._to_image(img), scene_ad, xform.shape_id, result_path, scale)

    def render_converge(self, num_iter=16, scale=0.4):
        """Average ``num_iter`` BoundaryIntegrator derivative renders made with different seeds.

        The first render uses seed 2 and render ``i`` of the remaining ``num_iter - 1``
        uses seed ``i + 5``. Each of those remaining renders is saved as
        ``out_dir/val/<prefix>backward_<i>``; the mean of all ``num_iter`` renders
        is saved as ``<prefix>backward_full``. NaNs are zeroed per render before
        averaging, so a NaN in one sample no longer zeroes that pixel in the mean.
        ``options.seed`` is restored afterwards.

        Raises:
            ValueError: if ``num_iter`` < 1.
        """
        if num_iter < 1:
            raise ValueError(f"num_iter must be >= 1, got {num_iter}")
        xform = self._load_transform()
        self._configure_edges(xform)
        d_flat = self._target_image().reshape(-1)
        self._attach_transform(xform)

        scene_ad = psdr_cpu.SceneAD(self.scene)
        # Result unused, but the call is kept: it may write derivatives into scene_ad.der,
        # which is printed below -- NOTE(review): confirm whether it is needed.
        self.integrator.renderD(scene_ad, self.options, d_flat)

        boundary_integrator = psdr_cpu.BoundaryIntegrator(scene_ad.val)
        ind_integrator = psdr_cpu.IndirectEdgeIntegrator(scene_ad.val)
        boundary_integrator.configure_mala(self.mala_options)
        self._preprocess_guiding(scene_ad, boundary_integrator, ind_integrator)

        original_seed = self.options.seed
        try:
            self.options.seed = 2
            img_full = self._to_image(
                boundary_integrator.renderD(scene_ad, self.options, d_flat))
            for i in range(num_iter - 1):
                self.options.seed = i + 5
                img = self._to_image(
                    boundary_integrator.renderD(scene_ad, self.options, d_flat))
                img_full += img
                result_path = os.path.join(
                    self.out_dir, "val", self.prefix + "backward_" + str(i))
                self._report_and_save(img, scene_ad, xform.shape_id, result_path, scale)
        finally:
            # Don't leak the per-iteration seed into later renders with the same options.
            self.options.seed = original_seed

        img_full /= num_iter
        result_path = os.path.join(self.out_dir, self.prefix + "backward_full")
        self._report_and_save(img_full, scene_ad, xform.shape_id, result_path, scale)

    def render_fd(self):
        """Forward finite-difference derivative: (render(scene moved by delta) - render(scene)) / delta.

        Moves either vertex ``xform.vertex_id`` or (if negative) every vertex of the
        target shape, keeps channel 0 only, and writes ``<prefix>fd`` as .exr/.png.
        NOTE(review): the perturbed vertices are left in the scene afterwards.
        """
        delta = self.delta
        image1 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)

        xform = self._load_transform()
        vertices = np.array(self.scene.shapes[xform.shape_id].vertices)
        move = xform.transformation.transform
        if xform.vertex_id >= 0:
            vertices[xform.vertex_id] = move(vertices[xform.vertex_id], delta)
        else:
            for i, vertex in enumerate(vertices):
                vertices[i] = move(vertex, delta)
        self.scene.shapes[xform.shape_id].vertices = vertices
        self.scene.configure()
        image2 = self.integrator.renderC(
            self.scene, self.options).reshape(self.height, self.width, 3)

        fd = (image2 - image1) / delta
        fd[:, :, 1:3] = 0
        scale = 2  # colour-map range is [-scale, scale]
        fname = os.path.join(self.out_dir, self.prefix + "fd")
        imwrite(fd, fname + ".exr")
        convertEXR2ColorMap(fname + ".exr", fname + ".png", -scale, scale, 1.0, False)

    def d_render(self):
        """Run the main integrator's renderD and print the absolute vertex-gradient sum of shape 0."""
        scene_ad = psdr_cpu.SceneAD(self.scene)
        d_image = self._target_image()
        self.integrator.renderD(
            scene_ad, self.options, d_image.reshape(-1).astype(np.float32))
        print(torch.tensor(scene_ad.der.shapes[0].vertices).abs().sum())

    def run(self):
        """Dispatch on render_type: "forward", "backward", "fd" or "converge".

        Raises:
            ValueError: for any other render_type.
        """
        dispatch = {
            "forward": self.renderC,
            "backward": self.renderD,
            "fd": self.render_fd,
            "converge": self.render_converge,
        }
        try:
            render = dispatch[self.render_type]
        except KeyError:
            raise ValueError(
                f"unknown render_type {self.render_type!r}; "
                f"expected one of {sorted(dispatch)}") from None
        render()

def write_vol(fname, data, size):
    """Write a single-channel float32 volume to ``fname`` in a binary VOL layout.

    The layout appears to be Mitsuba's grid-volume format (NOTE(review): confirm).
    All multi-byte fields are written little-endian, independent of host byte order:
      * b"VOL" magic followed by a one-byte version (3)
      * uint32 value 1 (float32 payload encoding)
      * three uint32 grid resolutions taken from ``size``
      * uint32 channel count 1
      * six float32 bounding-box values (0, 0, 0, 1, 1, 1)
      * ``size[0] * size[1] * size[2]`` float32 voxel values

    Args:
        fname: output path; the file is created or overwritten.
        data: array-like of voxel values (numpy array, CPU tensor, list).
            Multi-dimensional input is flattened in C order; the caller must
            make sure that order matches the expected voxel order.
        size: three ints, the grid resolution written to the header.

    Raises:
        ValueError: if ``size`` does not have three entries, or ``data`` does not
            hold exactly ``size[0] * size[1] * size[2]`` values. Validation happens
            before the file is opened, so no truncated file is left behind.
    """
    import struct

    if len(size) != 3:
        raise ValueError(f"size must have 3 entries, got {len(size)}")
    n = size[0] * size[1] * size[2]
    # Convert once to little-endian float32; tobytes() below avoids unpacking
    # every voxel as a separate struct.pack argument.
    voxels = np.asarray(data, dtype="<f4").reshape(-1)
    if voxels.size != n:
        raise ValueError(
            f"expected {n} voxel values for size {tuple(size)}, got {voxels.size}")

    with open(fname, "wb") as fout:
        fout.write(b"VOL")
        fout.write(bytes([3]))                            # format version
        fout.write(struct.pack("<I", 1))                  # payload encoding: float32
        fout.write(struct.pack("<3I", *size))             # grid resolution
        fout.write(struct.pack("<I", 1))                  # channel count
        fout.write(struct.pack("<6f", 0, 0, 0, 1, 1, 1))  # bounding box
        fout.write(voxels.tobytes())

def main():
    """Parse the gin config named on the command line and run the convergence study.

    The config file defaults to ``./king_rotate.conf`` and is resolved relative to
    this script's directory (the working directory is changed to it first).
    """
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    default_config = './king_rotate.conf'
    parser = argparse.ArgumentParser(
        description='Script for generating validation results')
    parser.add_argument('config_file', metavar='config_file',
                        type=str, nargs='?', default=default_config, help='config file')
    # parse_known_args: extra command-line flags are tolerated and ignored.
    args, _unknown = parser.parse_known_args()
    # Dependency injection: arguments are injected into constructors from the gin config file.
    gin.add_config_file_search_path(os.getcwd())
    gin.parse_config_file(args.config_file, skip_unknown=True)
    test_runner = TestRunner()
    test_runner.render_converge()
    print("done")


if __name__ == "__main__":
    main()