from scipy.spatial.transform import Rotation as R
import dataclasses
import psdr_cpu
import gin
import os
import numpy as np

class GuidingOptions:
    """Attribute template listing the guiding settings and their defaults.

    NOTE(review): a ``@gin.configurable`` function that is also named
    ``GuidingOptions`` is defined later in this module. Once that definition
    runs, the module-level name refers to the function, not to this class.
    """
    # Guiding mode selector.
    guide_type = "" # empty, grid, aq
    # Which contribution(s) the guiding applies to.
    guide_option = "" # direct, indirect, both

    # Grid guiding configurations; None until populated.
    grid_config_direct: psdr_cpu.grid_config = None
    grid_config_indirect: psdr_cpu.grid_config = None

    # Sort configuration and "aq" guiding configurations; None until populated.
    sort_config: psdr_cpu.Sort_config = None
    aq_config_direct: psdr_cpu.aq_config = None
    aq_config_indirect: psdr_cpu.aq_config = None


@gin.configurable
def RenderOptions(seed, max_bounce, spp, sppe, sppse0, sppse1 = 0, sppe0 = 0):
    """Build a ``psdr_cpu.RenderOptions`` from the given settings.

    The underlying constructor is called positionally as
    ``(seed, spp, max_bounce, sppe, sppse0, False)`` -- note that ``spp``
    precedes ``max_bounce`` there, unlike in this function's signature.
    ``sppse1`` and ``sppe0`` are assigned as attributes afterwards.

    Returns:
        The configured ``psdr_cpu.RenderOptions`` object.
    """
    render_opts = psdr_cpu.RenderOptions(seed, spp, max_bounce, sppe, sppse0, False)
    render_opts.sppe0 = sppe0
    render_opts.sppse1 = sppse1
    return render_opts

@gin.configurable
def MALAOptions(num_chains = 300, num_samples = 5000, step_length = 1e-7, p_global = 0.1,
                mode = 0, burn_in = 0, thinning = 1, phase_one_samples = 24, across_edge = True, 
                use_gradient = True, scale_steplength = True, scale = np.array([1.0, 1.0, 1.0])):
    """Build a ``psdr_cpu.MALAOptions`` from the given settings.

    All arguments are forwarded positionally to the ``psdr_cpu.MALAOptions``
    constructor in the same order as they appear in this signature.

    NOTE: the default ``scale`` array is created once, when the function is
    defined, and is shared by every call that does not pass ``scale``.

    Returns:
        The constructed ``psdr_cpu.MALAOptions`` object.
    """
    return psdr_cpu.MALAOptions(
        num_chains,
        num_samples,
        step_length,
        p_global,
        mode,
        burn_in,
        thinning,
        phase_one_samples,
        across_edge,
        use_gradient,
        scale_steplength,
        scale,
    )

class _GuidingOptionsResult:
    """Per-call container for the settings assembled by ``GuidingOptions``."""

    def __init__(self):
        self.guide_type = ""
        self.guide_option = ""
        self.grid_config_direct = None
        self.grid_config_indirect = None
        self.sort_config = None
        self.aq_config_direct = None
        self.aq_config_indirect = None


def _build_aq_config(**fields):
    """Create a ``psdr_cpu.aq_config`` and assign each keyword as an attribute on it."""
    config = psdr_cpu.aq_config()
    for name, value in fields.items():
        setattr(config, name, value)
    return config


@gin.configurable
def GuidingOptions(guide_type = "", guide_option = "indirect", grid_config_direct = [1,1,1,1], grid_config_indirect = [1,1,1,1], sort_config = [], 
    direct_thold=0.01, direct_spg=32, direct_min_spg=16, direct_sample_decay=0.5, direct_weight_decay=0.5, direct_max_depth=8, direct_npass=0, direct_use_heap=False, direct_edge_draw=False, direct_max_depth_x=8, direct_max_depth_y=16, direct_max_depth_z=16, direct_eps=0.1, direct_shape_opt_id=-1, direct_local_backward=False,
    indirect_thold=0.01, indirect_spg=32, indirect_min_spg=16, indirect_sample_decay=0.5, indirect_weight_decay=0.5, indirect_max_depth=8, indirect_npass=0, indirect_use_heap=False, indirect_edge_draw=False, indirect_max_depth_x=8, indirect_max_depth_y=16, indirect_max_depth_z=16, indirect_eps=0.1, indirect_shape_opt_id=-1, indirect_local_backward=False):
    """Assemble the guiding settings into a fresh options object.

    Args:
        guide_type: ``""`` (no guiding configs are built), ``"grid"`` or ``"aq"``.
        guide_option: ``"direct"``, ``"indirect"`` or ``"both"``; selects which
            of the direct/indirect configs are built.
        grid_config_direct, grid_config_indirect: four-element sequences; the
            first three elements and the fourth element are passed as the two
            arguments of ``psdr_cpu.grid_config``. Only read, never modified.
        sort_config: empty for a default ``psdr_cpu.Sort_config()``, otherwise
            its first three elements are passed to ``psdr_cpu.Sort_config``.
            Only used when ``guide_type`` is non-empty.
        direct_*, indirect_*: copied onto the attribute of the same name
            (without the prefix) of the corresponding ``psdr_cpu.aq_config``
            when ``guide_type == "aq"``.

    Returns:
        An object with attributes ``guide_type``, ``guide_option``,
        ``sort_config``, ``grid_config_direct``, ``grid_config_indirect``,
        ``aq_config_direct`` and ``aq_config_indirect``. Configs that were not
        requested are ``None``.
    """
    # A new container is created on every call. Binding the module-level name
    # ``GuidingOptions`` here would pick up this very function (it is the last
    # definition of that name), so settings would be stored on a shared
    # function object and leak from one call to the next.
    options = _GuidingOptionsResult()
    options.guide_type = guide_type
    options.guide_option = guide_option
    options.sort_config = psdr_cpu.Sort_config()

    if guide_type == "":
        return options

    if len(sort_config) != 0:
        options.sort_config = psdr_cpu.Sort_config(sort_config[0], sort_config[1], sort_config[2])

    use_direct = guide_option in ("direct", "both")
    use_indirect = guide_option in ("indirect", "both")

    if guide_type == "grid":
        if use_direct:
            options.grid_config_direct = psdr_cpu.grid_config(grid_config_direct[0:3], grid_config_direct[3])
        if use_indirect:
            options.grid_config_indirect = psdr_cpu.grid_config(grid_config_indirect[0:3], grid_config_indirect[3])

    elif guide_type == "aq":
        if use_direct:
            options.aq_config_direct = _build_aq_config(
                thold=direct_thold,
                spg=direct_spg,
                min_spg=direct_min_spg,
                sample_decay=direct_sample_decay,
                weight_decay=direct_weight_decay,
                max_depth=direct_max_depth,
                npass=direct_npass,
                use_heap=direct_use_heap,
                edge_draw=direct_edge_draw,
                max_depth_x=direct_max_depth_x,
                max_depth_y=direct_max_depth_y,
                max_depth_z=direct_max_depth_z,
                eps=direct_eps,
                shape_opt_id=direct_shape_opt_id,
                local_backward=direct_local_backward,
            )
        if use_indirect:
            options.aq_config_indirect = _build_aq_config(
                thold=indirect_thold,
                spg=indirect_spg,
                min_spg=indirect_min_spg,
                sample_decay=indirect_sample_decay,
                weight_decay=indirect_weight_decay,
                max_depth=indirect_max_depth,
                npass=indirect_npass,
                use_heap=indirect_use_heap,
                edge_draw=indirect_edge_draw,
                max_depth_x=indirect_max_depth_x,
                max_depth_y=indirect_max_depth_y,
                max_depth_z=indirect_max_depth_z,
                eps=indirect_eps,
                shape_opt_id=indirect_shape_opt_id,
                local_backward=indirect_local_backward,
            )

    return options


@gin.configurable
def Direct():
    """Create a default-constructed ``psdr_cpu.Direct`` integrator."""
    direct = psdr_cpu.Direct()
    return direct


@gin.configurable
def Direct2():
    """Create a default-constructed ``psdr_cpu.Direct2`` integrator."""
    direct2 = psdr_cpu.Direct2()
    return direct2


@gin.configurable
def Path():
    """Create a default-constructed ``psdr_cpu.Path`` integrator."""
    path = psdr_cpu.Path()
    return path

@gin.configurable
def Volpath(enable_antithetic=True):
    """Create a ``psdr_cpu.Volpath`` and set its ``enable_antithetic`` attribute."""
    volpath = psdr_cpu.Volpath()
    volpath.enable_antithetic = enable_antithetic
    return volpath

@gin.configurable
def Volpath2(enable_antithetic=True):
    """Create a ``psdr_cpu.Volpath2`` and set its ``enable_antithetic`` attribute."""
    volpath2 = psdr_cpu.Volpath2()
    volpath2.enable_antithetic = enable_antithetic
    return volpath2
    
@gin.configurable
def VolpathMerged(enable_antithetic=True):
    """Create a ``psdr_cpu.VolpathMerged`` and set its ``enable_antithetic`` attribute."""
    merged = psdr_cpu.VolpathMerged()
    merged.enable_antithetic = enable_antithetic
    return merged

@gin.configurable
def Path2(enable_antithetic=True):
    """Create a ``psdr_cpu.Path2`` and set its ``enable_antithetic`` attribute."""
    path2 = psdr_cpu.Path2()
    path2.enable_antithetic = enable_antithetic
    return path2

@gin.configurable
def Mask(enable_antithetic=True):
    """Create a ``psdr_cpu.Mask`` and set its ``enable_antithetic`` attribute."""
    mask = psdr_cpu.Mask()
    mask.enable_antithetic = enable_antithetic
    return mask

@gin.configurable
def PTracer(enable_antithetic=True, is_equal_trans=False):
    """Create a ``psdr_cpu.PTracer`` and set its two configuration attributes.

    ``enable_antithetic`` and ``is_equal_trans`` are assigned to the
    attributes of the same name on the new object.
    """
    tracer = psdr_cpu.PTracer()
    tracer.enable_antithetic = enable_antithetic
    tracer.is_equal_trans = is_equal_trans
    return tracer


@gin.configurable
def BdptNaive():
    """Create a ``psdr_cpu.Bdpt`` constructed with ``False`` as its only argument."""
    naive = psdr_cpu.Bdpt(False)
    return naive

@gin.configurable
def Bdpt(enable_antithetic = True):
    """Create a ``psdr_cpu.Bdpt``, passing ``enable_antithetic`` to its constructor."""
    return psdr_cpu.Bdpt(enable_antithetic)

@gin.configurable
def DirectEdgeIntegrator(scene):
    """Create a ``psdr_cpu.DirectEdgeIntegrator`` for ``scene``."""
    edge_integrator = psdr_cpu.DirectEdgeIntegrator(scene)
    return edge_integrator


@gin.configurable
def IndirectEdgeIntegrator(scene):
    """Create a ``psdr_cpu.IndirectEdgeIntegrator`` for ``scene``."""
    edge_integrator = psdr_cpu.IndirectEdgeIntegrator(scene)
    return edge_integrator


@gin.configurable
def PrimaryEdgeIntegrator(scene):
    """Create a ``psdr_cpu.PrimaryEdgeIntegrator`` for ``scene``."""
    edge_integrator = psdr_cpu.PrimaryEdgeIntegrator(scene)
    return edge_integrator

@gin.configurable
def BoundaryIntegrator(scene):
    """Create a ``psdr_cpu.BoundaryIntegrator`` for ``scene``."""
    boundary = psdr_cpu.BoundaryIntegrator(scene)
    return boundary

@gin.configurable
def PixelBoundaryIntegrator(scene, enable_antithetic = True):
    """Create a ``psdr_cpu.PixelBoundaryIntegrator`` for ``scene``.

    ``enable_antithetic`` is assigned to the attribute of the same name on
    the new object before it is returned.
    """
    pixel_boundary = psdr_cpu.PixelBoundaryIntegrator(scene)
    pixel_boundary.enable_antithetic = enable_antithetic
    return pixel_boundary

@gin.configurable
class CompositeIntegrator:
    """Groups several integrators behind a single renderC/renderD interface."""

    def __init__(self, integrators):
        # Iterable of integrator objects; visited in iteration order.
        self.integrators = integrators

    def renderC(self, scene, options):
        """Render ``scene`` and return an array of shape (height, width, 3).

        Height and width are read from ``scene.camera``. If there are no
        integrators, a float32 array of zeros is returned.
        """
        assert isinstance(scene, psdr_cpu.Scene)
        assert isinstance(options, psdr_cpu.RenderOptions)
        image_shape = (scene.camera.height, scene.camera.width, 3)
        for first in self.integrators:
            # NOTE(review): this returns during the first iteration, so only
            # the first integrator contributes to the image; renderD, by
            # contrast, sums over all integrators. Confirm this is intended.
            return first.renderC(scene, options).reshape(image_shape)
        return np.zeros(image_shape, dtype=np.float32)

    def renderD(self, sceneAD, options, d_image):
        """Sum the ``renderD`` outputs of every integrator.

        Each integrator's result is reshaped to (height, width, 3), with the
        size taken from ``sceneAD.val.camera``, and accumulated in place into
        a float32 array, which is returned.
        """
        assert isinstance(sceneAD, psdr_cpu.SceneAD)
        assert isinstance(options, psdr_cpu.RenderOptions)
        cam = sceneAD.val.camera
        grad_shape = (cam.height, cam.width, 3)
        total = np.zeros(grad_shape, dtype=np.float32)
        for part in self.integrators:
            total += part.renderD(sceneAD, options, d_image).reshape(grad_shape)
        return total
        

import pypsdr
@gin.configurable
def Scene(file_name, config = None):
    """Load a ``psdr_cpu`` scene from ``file_name``.

    The scene is loaded with the process working directory temporarily set to
    the directory containing ``file_name``; the previous working directory is
    restored afterwards, even when loading raises.

    Args:
        file_name: Path to the scene file. A bare file name (no directory
            component) refers to the current working directory.
        config: ``None`` to load the file directly with ``psdr_cpu.Scene``.
            Otherwise a mapping; the file is read as text and, if the mapping
            contains ``'mesh_file'`` and/or ``'texture_file'``, the text is
            treated as a ``str.format`` template and those placeholders are
            replaced by the absolute paths of the given values. The result is
            loaded with ``psdr_cpu.load_from_string`` and then configured.

    Returns:
        The loaded scene object.

    Raises:
        KeyError: if the template contains a named placeholder that is not
            supplied by ``config`` (raised by ``str.format``).
    """
    scene_str = None
    if config is not None:
        with open(file_name, "r") as f:
            scene_str = f.read()
        # Resolve the paths before changing directory, and substitute all
        # placeholders in one format() call: formatting once per key fails
        # with KeyError on the not-yet-supplied placeholder when the template
        # contains both, and would also unescape "{{"/"}}" twice.
        substitutions = {
            key: os.path.abspath(config[key])
            for key in ('mesh_file', 'texture_file')
            if key in config
        }
        if substitutions:
            scene_str = scene_str.format(**substitutions)

    # abspath() first so that a bare file name yields a real directory
    # instead of "" (which os.chdir rejects).
    scene_dir = os.path.dirname(os.path.abspath(file_name))
    cur_dir = os.getcwd()
    os.chdir(scene_dir)
    try:
        if config is None:
            return psdr_cpu.Scene("./" + os.path.basename(file_name))
        scene = psdr_cpu.load_from_string(scene_str)
        scene.configure()
        return scene
    finally:
        os.chdir(cur_dir)

# ================== Transformation ===================


@gin.configurable
class Transformation:
    """Empty base class for the transformation types defined in this module."""


@gin.configurable
@dataclasses.dataclass
class Translation(Transformation):
    """Transformation that adds a scaled offset to a vertex."""

    # Offset components; converted with np.array when applied.
    translation: list

    def transform(self, vertex, delta=1.):
        """Return ``vertex`` plus ``translation`` scaled by ``delta``."""
        offset = np.array(self.translation) * delta
        return vertex + offset


@gin.configurable
@dataclasses.dataclass
class Rotation(Transformation):
    """Transformation that rotates a vertex by a scaled rotation vector."""

    # Rotation-vector components; converted with np.array when applied.
    rotation: list

    def transform(self, vertex, delta=1.):
        """Rotate ``vertex`` by the rotation vector ``rotation * delta``.

        The scaled vector is turned into a rotation with
        ``R.from_rotvec`` and applied to ``vertex``.
        """
        rotvec = np.array(self.rotation) * delta
        return R.from_rotvec(rotvec).apply(vertex)


@gin.configurable
@dataclasses.dataclass
class Transform:
    """Pairs a transformation with the shape/vertex ids it is meant for."""
    # Id of the target shape; defaults to -1.
    shape_id: int = -1
    vertex_id: int = -1  # -1 means all vertices
    # Transformation object to apply; None when not set.
    transformation: Transformation = None
