import argparse
from email.policy import default
import time
from typing import List
from cv2 import transform
import gin
import psdr_cpu
from pypsdr.validate import *
from pypsdr.utils.io import *
from pypsdr.utils.exr2png import convertEXR2ColorMap
import dataclasses
import os
import numpy as np
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
torch.random.manual_seed(10)
# import yep
import mitsuba as mi
mi.set_variant('scalar_rgb')

# Process-wide environment flags for the native libraries used below.
# NOTE(review): presumably KMP_DUPLICATE_LIB_OK tolerates multiple OpenMP
# runtimes being loaded, and OPENCV_IO_ENABLE_OPENEXR turns on OpenCV's EXR
# codec — confirm against the respective library docs.
os.environ.update({
    'KMP_DUPLICATE_LIB_OK': 'True',
    "OPENCV_IO_ENABLE_OPENEXR": "1",
})

# The @gin.configurable decorator on TestRunner (below) injects config-file
# bindings into its dataclass-generated constructor.

# Per-scene constant; values the author used: lucy mirror 50, nefertiti 5,
# lamp bulb 15.
scale = 15

@gin.configurable
@dataclasses.dataclass
class TestRunner:
    """Validation driver for a psdr_cpu scene.

    ``@gin.configurable`` injects bindings from the gin config file into the
    constructor that ``@dataclasses.dataclass`` generates for the fields below.
    """

    scene: psdr_cpu.Scene
    options: psdr_cpu.RenderOptions
    mala_options: psdr_cpu.MALAOptions
    integrator: psdr_cpu.Integrator
    aug_integrators: List = dataclasses.field(default_factory=list)
    # Build a fresh GuidingOptions per instance. A plain ``= GuidingOptions()``
    # default is evaluated once and shared by every TestRunner, and Python
    # 3.11+ dataclasses reject such a default outright if it is unhashable.
    guiding_options: GuidingOptions = dataclasses.field(
        default_factory=GuidingOptions)
    out_dir: str = "./"
    render_type: str = "forward"
    suffix: str = ""
    delta: float = 0.005

    def __post_init__(self):
        """Create output directories and derive image size and file prefix.

        Creates ``out_dir`` and ``out_dir/val``. Records the camera
        width/height and builds ``prefix`` as
        ``"<IntegratorClass>_<spp>_"``, followed by ``"<suffix>_"`` when a
        suffix is set. Finally enables psdr_cpu verbose output.
        """
        os.makedirs(self.out_dir, exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, "val"), exist_ok=True)
        self.width = self.scene.camera.width
        self.height = self.scene.camera.height
        self.prefix = f"{type(self.integrator).__name__}_{self.options.spp}_"
        if self.suffix:
            self.prefix += f"{self.suffix}_"
        psdr_cpu.set_verbose(True)

    def render_ray(self, org, dir, fname):
        """Render a debug image of the line through ``org`` along ``dir``.

        Writes the two endpoints ``org - 10*dir`` and ``org + 10*dir`` to
        ``../data/scenes/two_triangles/curve.txt``. Each endpoint goes on its
        own line as "x y z w", with w = 0.01 for the first point and 0.05 for
        the second.

        Then it loads ``mitsuba_vis.xml`` from the same folder, renders it
        with Mitsuba at 128 spp and writes the image to ``fname``.

        NOTE(review): presumably ``mitsuba_vis.xml`` reads ``curve.txt`` and w
        is the curve width — confirm against the scene file.
        """
        org_arr = np.array(org)
        dir_arr = np.array(dir)
        # Compute before opening so a bad input does not truncate curve.txt.
        endpoints = ((org_arr - dir_arr * 10, "0.01"),
                     (org_arr + dir_arr * 10, "0.05"))
        with open("../data/scenes/two_triangles/curve.txt", "w") as f:
            for point, width in endpoints:
                xyz = " ".join(map(str, point[:3]))
                f.write(f"{xyz} {width}\n")
        scene = mi.load_file("../data/scenes/two_triangles/mitsuba_vis.xml")
        original_image = mi.render(scene, spp=128)

        mi.util.write_bitmap(fname, original_image)

    def mutation_test(self):
        """Visualize one edge-sample perturbation of ``IndirectEdgeMLT``.

        Disables edges on every shape except shape 0, marks shape 0 drawable
        and reconfigures the scene and that shape. It then perturbs the sample
        ``rnd`` by ``mutation_offset`` with ``perturbe_sample``. For every
        returned path entry, it renders the corresponding edge ray to
        ``./output/mutation/ray_<i>.png``.
        """
        print(len(self.scene.shapes))
        for shape in self.scene.shapes:
            shape.enable_edge = False
        self.scene.shapes[0].enable_edge = True
        self.scene.configure()
        self.scene.shapes[0].enable_draw = True
        self.scene.shapes[0].configure()

        mlt_integrator = psdr_cpu.IndirectEdgeMLT(self.scene)
        rnd = np.array([0.3, 0.3, 0.3])
        mutation_offset = np.array([0.02, 0.0, 0.3])

        path = np.array(
            mlt_integrator.perturbe_sample(self.scene, rnd, mutation_offset))

        # render_ray writes its images here; make sure the folder exists.
        out_dir = "./output/mutation"
        os.makedirs(out_dir, exist_ok=True)
        for i, sample in enumerate(tqdm(path)):
            ray = mlt_integrator.get_edge_ray(self.scene, sample)
            self.render_ray(ray[0], ray[1], f"{out_dir}/ray_{i}.png")


if __name__ == "__main__":
    import sys

    # Run from the script's own directory so the relative paths used below
    # (config file, ../data/..., ./output/...) resolve consistently.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    default_config = './two_triangles.conf'
    parser = argparse.ArgumentParser(
        description='Script for generating validation results')
    parser.add_argument('config_file', metavar='config_file',
                        type=str, nargs='?', default=default_config, help='config file')
    # parse_known_args tolerates extra command-line flags; they are ignored.
    args, _unknown = parser.parse_known_args()
    # Dependency injection: Arguments are injected into the function from the gin config file.
    gin.add_config_file_search_path(os.getcwd())
    gin.parse_config_file(args.config_file, skip_unknown=True)
    test_runner = TestRunner()
    # yep.start("val.prof")

    # direct_integrator = psdr_cpu.DirectEdgeMLT(test_runner.scene)
    # out_arr = direct_integrator.diff_bsdf_test(0.05, -0.1, 0.1)
    # bsdf = out_arr[6:]
    # print(out_arr)
    # print(bsdf, np.linalg.norm(bsdf))

    test_runner.mutation_test()
    # test_runner.render_converge()
    # NOTE(review): these option overrides only matter for the commented-out
    # render calls; nothing visible below reads them.
    test_runner.options.spp = 128
    test_runner.options.sppse0 = 0
    # test_runner.renderC()
    print("done")
    # sys.exit rather than the interactive-shell `exit()` builtin, which is
    # provided by the site module and is not guaranteed to exist.
    sys.exit()
    