from numpy import dtype
import torch
from pypsdr.optimizer import LargeSteps

from pypsdr.validate import BoundaryIntegrator, Direct, GuidingOptions, PrimaryEdgeIntegrator, IndirectEdgeIntegrator, DirectEdgeIntegrator, RenderOptions
from .utils.timer import Timer
import psdr_cpu
import numpy as np
from scipy.spatial.transform import Rotation as R


class RenderFunction(torch.autograd.Function):
    """Differentiable render w.r.t. the vertices of a single shape.

    ``forward`` writes ``V`` into ``scene.shapes[context['shape_id']]``,
    switches to camera ``context['cameras'][sensor_id]`` and renders with
    ``context['integrator']``. ``backward`` runs the interior
    (``integrator.renderD``) and boundary (``boundary_integrator.renderD``)
    derivative passes and returns the vertex gradient of that shape, with NaN
    entries set to zero.
    """

    @staticmethod
    def forward(ctx, V: torch.Tensor, sensor_id, context, params):
        """Render the scene with vertices ``V`` from camera ``sensor_id``.

        Args:
            V: vertex positions of the optimized shape (passed to
                ``setVertices`` via ``detach().numpy()``).
            sensor_id: index into ``context['cameras']``.
            context: shared render context dict; updated in place with
                ``params``.
            params: overrides merged into ``context`` before rendering.

        Returns:
            float32 image tensor of shape ``(height, width, 3)``.
        """
        ctx.context = context
        # Overrides persist: the context dict is shared with the owning module.
        ctx.context.update(params)

        sceneAD = context["sceneAD"]
        shape_id = context['shape_id']
        integrator = context['integrator']
        cameras = context['cameras']
        scene = sceneAD.val

        # update shape
        scene.shapes[shape_id].setVertices(V.detach().numpy())
        scene.shapes[shape_id].configure()  # !

        # switch camera
        ctx.sensor_id = sensor_id
        scene.camera = cameras[sensor_id]
        scene.configure()

        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def _preprocess_guiding(boundary_integrator, scene, guiding_options, options):
        """Run the boundary-guiding preprocessing selected by ``guiding_options``.

        Does nothing when ``guide_type`` is ``""``. ``guide_option`` of
        ``"direct"``/``"indirect"``/``"both"`` picks which edge sets are
        recomputed; ``guide_type == "grid"`` uses the grid preprocessors and
        any other non-empty value uses the ``aq`` preprocessors.
        """
        if guiding_options.guide_type == "":
            return
        use_grid = guiding_options.guide_type == "grid"
        if guiding_options.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_direct(
                    scene, guiding_options.grid_config_direct, options.max_bounces)
            else:
                boundary_integrator.preprocess_aq_direct(
                    scene, guiding_options.aq_config_direct, options.max_bounces)
        if guiding_options.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_indirect(
                    scene, guiding_options.grid_config_indirect, options.max_bounces)
            else:
                boundary_integrator.preprocess_aq_indirect(
                    scene, guiding_options.aq_config_indirect, options.max_bounces)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return the gradient w.r.t. ``V`` (NaNs zeroed); other inputs get None."""
        context = ctx.context
        sceneAD = context["sceneAD"]
        sceneAD.zeroGrad()
        scene = sceneAD.val
        shape_id = context['shape_id']
        options = context["options"]
        integrator = context['integrator']
        boundary_integrator = context['boundary_integrator']
        guiding_options = context["guiding_options"]

        # Re-select the camera used by this forward call: another forward with
        # a different sensor_id may have switched it since.
        scene.camera = context["cameras"][ctx.sensor_id]

        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix; forward
        # returns float32, so grad_out is float32 as well.
        grad_image = grad_out.reshape(-1).numpy()

        # Interior (non-boundary) derivative.
        # TODO change to d_render
        integrator.renderD(sceneAD, options, grad_image)

        # The primary-edge data depends on the camera, so the boundary
        # integrator is reconfigured on every backward call (once is enough).
        boundary_integrator.configure_primary(scene)
        RenderFunction._preprocess_guiding(
            boundary_integrator, scene, guiding_options, options)
        boundary_integrator.renderD(sceneAD, options, grad_image)

        grad_vertx = torch.tensor(sceneAD.der.shapes[shape_id].vertices)
        grad_vertx[grad_vertx.isnan()] = 0.
        return grad_vertx, None, None, None


class Render(torch.nn.Module):
    """``nn.Module`` front-end for ``RenderFunction`` (single-shape vertices).

    Holds the shared render context: scene, camera list, integrator, a
    ``psdr_cpu.BoundaryIntegrator`` built from ``sceneAD.val``, render and
    guiding options, and the index of the optimized shape.
    """

    def __init__(self, sceneAD, cameras, integrator, options, shape_id, guiding_options=None):
        super(Render, self).__init__()
        if guiding_options is None:
            # Built per instance: a default-argument object would be created
            # once at import time and shared by every Render omitting it.
            guiding_options = GuidingOptions("")
        self.context = dict(
            sceneAD=sceneAD,
            cameras=cameras,
            integrator=integrator,
            boundary_integrator=psdr_cpu.BoundaryIntegrator(sceneAD.val),
            options=options,
            guiding_options=guiding_options,
            shape_id=shape_id
        )

    def setState(self, state):
        """Merge ``state`` into the context; rebuild the boundary integrator
        when a new ``sceneAD`` is supplied."""
        self.context.update(state)
        if "sceneAD" in state:
            sceneAD = self.context["sceneAD"]
            self.context['boundary_integrator'] = psdr_cpu.BoundaryIntegrator(
                sceneAD.val)

    def forward(self, V: torch.Tensor, sensor_id, params=None):
        """Render with vertices ``V`` from camera ``sensor_id``; ``params``
        (optional) is merged into the shared context first."""
        if params is None:
            params = {}
        return RenderFunction.apply(
            V, sensor_id, self.context, params)


class RenderFunction2(torch.autograd.Function):
    """Differentiable render of one shape under a per-sensor rigid rotation.

    For ``sensor_id`` the vertices ``V`` of ``scene.shapes[context['shape_id']]``
    are rotated by the axis-angle ``context['obj_pos'][sensor_id]`` before
    rendering with the scene's current camera. ``backward`` rotates the vertex
    derivative back, so the returned gradient is w.r.t. the un-rotated ``V``.
    """

    @staticmethod
    def _rotate(points, pos, inverse=False):
        """Return ``points`` (one 3D point per row) rotated by axis-angle ``pos``.

        The rotation vector is ``pos[3] * pos[0:3]`` (``pos[0:3]`` presumably a
        unit axis and ``pos[3]`` the angle in radians); ``inverse=True``
        negates the angle. Each row ``p`` is mapped to ``M @ p`` where ``M`` is
        the rotation matrix.
        """
        pos = np.asarray(pos, dtype=float)
        angle = -pos[3] if inverse else pos[3]
        rotation = R.from_rotvec(angle * pos[0:3])
        return np.dot(points, rotation.as_matrix().transpose())

    @staticmethod
    def _set_vertices(scene, shape_id, vertices):
        """Write ``vertices`` into ``scene.shapes[shape_id]`` and reconfigure."""
        scene.shapes[shape_id].setVertices(vertices)
        scene.shapes[shape_id].configure()  # !
        scene.configure()

    @staticmethod
    def _preprocess_guiding(boundary_integrator, scene, guiding_options, options):
        """Run the boundary-guiding preprocessing selected by ``guiding_options``.

        Does nothing when ``guide_type`` is ``""``. ``guide_option`` of
        ``"direct"``/``"indirect"``/``"both"`` picks which edge sets are
        recomputed; ``guide_type == "grid"`` uses the grid preprocessors and
        any other non-empty value uses the ``aq`` preprocessors.
        """
        if guiding_options.guide_type == "":
            return
        use_grid = guiding_options.guide_type == "grid"
        if guiding_options.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_direct(
                    scene, guiding_options.grid_config_direct, options.max_bounces)
            else:
                boundary_integrator.preprocess_aq_direct(
                    scene, guiding_options.aq_config_direct, options.max_bounces)
        if guiding_options.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_indirect(
                    scene, guiding_options.grid_config_indirect, options.max_bounces)
            else:
                boundary_integrator.preprocess_aq_indirect(
                    scene, guiding_options.aq_config_indirect, options.max_bounces)

    @staticmethod
    def forward(ctx, V: torch.Tensor, sensor_id, context, params):
        """Render the shape posed for ``sensor_id``.

        Args:
            V: un-rotated vertex positions of the optimized shape.
            sensor_id: index into ``context['obj_pos']``.
            context: shared render context dict; updated in place with
                ``params``.
            params: overrides merged into ``context`` before rendering.

        Returns:
            float32 image tensor of shape ``(height, width, 3)``.
        """
        ctx.context = context
        ctx.context.update(params)  # update the default context

        sceneAD = context["sceneAD"]
        shape_id = context['shape_id']
        integrator = context['integrator']
        scene = sceneAD.val

        pos = np.array(context["obj_pos"][sensor_id], dtype=float)
        rot_position = RenderFunction2._rotate(V.detach().numpy(), pos)
        # Saved so backward can restore exactly this pose, even if a later
        # forward call (another sensor_id) re-posed the shared scene.
        ctx.sensor_id = sensor_id
        ctx.pos = pos
        ctx.rot_position = rot_position
        RenderFunction2._set_vertices(scene, shape_id, rot_position)

        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return the gradient w.r.t. the un-rotated ``V`` (NaNs zeroed)."""
        context = ctx.context
        sceneAD = context["sceneAD"]
        scene = sceneAD.val
        shape_id = context['shape_id']
        options = context["options"]
        integrator = context['integrator']
        boundary_integrator = context['boundary_integrator']
        guiding_options = context["guiding_options"]

        # Restore the pose used by this forward call before differentiating.
        RenderFunction2._set_vertices(scene, shape_id, ctx.rot_position)
        sceneAD.zeroGrad()

        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix; forward
        # returns float32, so grad_out is float32 as well.
        grad_image = grad_out.reshape(-1).numpy()
        integrator.renderD(sceneAD, options, grad_image)

        # boundary integrator need to be reconfigured every time
        boundary_integrator.configure_primary(scene)
        RenderFunction2._preprocess_guiding(
            boundary_integrator, scene, guiding_options, options)
        boundary_integrator.renderD(sceneAD, options, grad_image)

        # Forward mapped V -> V @ M.T, so dL/dV = dL/dV_rot @ M, which is the
        # derivative rotated by the inverse rotation.
        rot_grad = RenderFunction2._rotate(
            sceneAD.der.shapes[shape_id].vertices, ctx.pos, inverse=True)
        grad_vertx = torch.tensor(rot_grad)
        grad_vertx[grad_vertx.isnan()] = 0.
        return grad_vertx, None, None, None


class Render2(torch.nn.Module):
    """``nn.Module`` front-end for ``RenderFunction2`` (per-sensor rotated shape).

    ``obj_pos[sensor_id]`` holds the axis-angle pose applied to the shape for
    that sensor.
    """

    def __init__(self, sceneAD, cameras, integrator, options, guiding_options, shape_id, obj_pos):
        super(Render2, self).__init__()
        self.context = dict(
            sceneAD=sceneAD,
            cameras=cameras,
            integrator=integrator,
            boundary_integrator=psdr_cpu.BoundaryIntegrator(sceneAD.val),
            options=options,
            guiding_options=guiding_options,
            shape_id=shape_id,
            obj_pos=obj_pos
        )

    def setState(self, state):
        """Merge ``state`` into the context; rebuild the boundary integrator
        when a new ``sceneAD`` is supplied."""
        self.context.update(state)
        if "sceneAD" in state:
            sceneAD = self.context["sceneAD"]
            self.context['boundary_integrator'] = psdr_cpu.BoundaryIntegrator(
                sceneAD.val)

    def forward(self, V: torch.Tensor, sensor_id, params=None):
        """Render ``V`` posed for ``sensor_id``; ``params`` (optional) is merged
        into the shared context first."""
        if params is None:
            params = {}
        return RenderFunction2.apply(
            V, sensor_id, self.context, params)


class RenderFunction3(torch.autograd.Function):
    """Differentiable render w.r.t. the reflectance texture of ``scene.bsdfs[0]``.

    The ``psdr_cpu.Bitmap`` resolution is ``context["texture_res"]`` when
    present, otherwise the historical ``(225, 225)``. The returned gradient has
    the same shape as the input ``texture``.
    """

    @staticmethod
    def forward(ctx, texture: torch.Tensor, context):
        """Install ``texture`` as the reflectance of ``bsdfs[0]`` and render.

        Returns:
            float32 image tensor of shape ``(height, width, 3)``.
        """
        ctx.context = context
        # The gradient must match the input's shape, whatever layout was used.
        ctx.texture_shape = texture.shape
        sceneAD = context["sceneAD"]
        integrator = context['integrator']
        scene = sceneAD.val

        resolution = tuple(context.get("texture_res", (225, 225)))
        bitmap = psdr_cpu.Bitmap(
            texture.reshape(-1).detach().numpy(), resolution)
        scene.bsdfs[0].reflectance = bitmap
        scene.configure()

        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return the texture gradient (NaNs zeroed) shaped like the input."""
        context = ctx.context
        sceneAD = context["sceneAD"]
        sceneAD.zeroGrad()
        options = context["options"]
        integrator = context['integrator']
        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix
        integrator.renderD(sceneAD,
                           options, grad_out.reshape(-1).numpy())
        grad_texture = torch.tensor(
            np.array(sceneAD.der.bsdfs[0].reflectance.m_data)
        ).reshape(ctx.texture_shape)
        grad_texture[grad_texture.isnan()] = 0.
        return grad_texture, None


class Render3(torch.nn.Module):
    """``nn.Module`` front-end for ``RenderFunction3`` (texture optimization)."""

    def __init__(self, sceneAD, integrator, options):
        super(Render3, self).__init__()
        self.context = dict(
            sceneAD=sceneAD,
            integrator=integrator,
            options=options
        )

    def setState(self, state):
        """Merge ``state`` into the render context."""
        self.context.update(state)

    def forward(self, texture: torch.Tensor, params=None):
        """Render with ``texture``; ``params`` (optional, as in ``Render`` and
        ``Render2``) is merged into the shared context first."""
        if params is not None:
            self.context.update(params)
        return RenderFunction3.apply(
            texture, self.context)


class JointRenderFunction(torch.autograd.Function):
    """Differentiable render w.r.t. every parameter in ``context["param_map"]``.

    ``x`` is the concatenation of the flattened parameters. ``forward`` does
    not read it: the scene is expected to already hold those values (see
    ``Model.update_scene``). ``backward`` returns the scene derivative of each
    ``param_map`` key, flattened and concatenated in the same key order, with
    NaN/inf replaced via ``torch.nan_to_num``.
    """

    @staticmethod
    def forward(ctx, x, context):
        """Render ``context["sceneAD"].val``; returns a float32 (H, W, 3) image."""
        ctx.context = context
        # Remember the camera this image was rendered from, so backward can
        # restore it if later forward calls switched cameras.
        ctx.sensor_id = context.get("sensor_id")
        scene = context["sceneAD"].val
        with Timer("forward pass"):
            image = context["integrator"].renderC(scene, context["options"])
            return torch.tensor(image, dtype=torch.float).reshape(
                scene.camera.height, scene.camera.width, 3)

    @staticmethod
    def backward(ctx, img_grad):
        """Return (flattened parameter gradient, None)."""
        context = ctx.context
        sceneAD = context["sceneAD"]
        scene = sceneAD.val
        cameras = context.get("cameras")
        if ctx.sensor_id is not None and cameras is not None:
            scene.camera = cameras[ctx.sensor_id]
        sceneAD.zeroGrad()

        grad_image = img_grad.reshape(-1).numpy().astype(np.float32)
        with Timer("backward pass"):
            context["integrator"].renderD(sceneAD, context["options"], grad_image)

        # The primary-edge data depends on the camera; reconfigure each call.
        context["boundary_integrator"].configure_primary(scene)
        context["boundary_integrator"].renderD(sceneAD, context["options"], grad_image)

        # A plain loop (not a comprehension) keeps `sceneAD` visible to eval.
        # NOTE(review): keys are evaluated as Python expressions; they come from
        # the caller's param_ids and must never be untrusted input.
        param_map = {}
        for key in context["param_map"]:
            param_map[key] = torch.tensor(
                np.array(eval("sceneAD.der." + key)), dtype=torch.float32)

        for key in param_map:
            print("grad ", key, param_map[key].sum())

        param_grad = torch.cat([param_map[key].reshape(-1) for key in param_map])
        param_grad = torch.nan_to_num(param_grad)
        return param_grad, None


class Model:
    """Optimizable parameters of a differentiable scene plus their optimizers.

    Args:
        sceneAD: differentiable scene; ``sceneAD.val`` is the primal scene and
            ``sceneAD.der`` holds derivatives.
        param_ids: mapping from an attribute expression relative to
            ``sceneAD.val`` (presumably like ``"shapes[0].vertices"`` — confirm
            against callers) to optimizer settings; ``settings["optimizer"]``
            is ``"adam"`` or ``"largestep"`` (see ``get_optimizer``).
        context: optional overrides merged into the default render context.
    """

    def __init__(self,
                 sceneAD,
                 param_ids,  # parameter indices in the Scene
                 context=None):
        self.sceneAD = sceneAD
        # parameters to be optimized
        self.param_map = {
            key: torch.tensor(np.array(self.eval("self.sceneAD.val." + key)),
                              dtype=torch.float32,
                              requires_grad=True)
            for key in param_ids
        }
        # optimizers for parameters in the param_map
        self.optimizers = [
            self.get_optimizer(params=[self.param_map[key]],
                               props=param_ids[key])
            for key in param_ids
        ]
        self.default_context = {
            "sceneAD": sceneAD,
            "integrator": Direct(),
            "boundary_integrator": psdr_cpu.BoundaryIntegrator(sceneAD.val),
            "options": RenderOptions(seed=1,
                                     max_bounce=1,
                                     spp=128,
                                     sppe=0,
                                     sppse0=0,
                                     sppse1=0),
            "param_map": self.param_map,  # parameters to be optimized
            "cameras": [sceneAD.val.camera],
            "sensor_id": 0
        }
        # NOTE: self.context is the same dict object as default_context, so
        # overrides (here and in render) are also visible in default_context.
        self.context = self.default_context
        if context is not None:
            self.context.update(context)

    def render(self, context=None):
        """Render with camera ``sensor_id``; ``context`` overrides persist."""
        ctx = self.context
        if context is not None:
            ctx.update(context)
        ctx["sceneAD"].val.camera = ctx["cameras"][ctx["sensor_id"]]
        x = torch.cat([ctx['param_map'][key].reshape(-1)
                       for key in ctx['param_map']])
        return JointRenderFunction.apply(x, ctx)

    def eval(self, s):
        """Evaluate expression ``s`` with ``self`` in scope (e.g. "self.sceneAD.val.x")."""
        return eval(s)

    def step(self):
        """Step every optimizer, then push the new values into the scene."""
        for optimizer in self.optimizers:
            optimizer.step()
        self.update_scene()

    def zero_grad(self):
        """Clear gradients of all optimized parameters."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def update_scene(self):
        """Write parameter values back into the scene and rebuild the boundary
        integrator for the reconfigured scene."""
        for key in self.param_map:
            exec("self.sceneAD.val." + key +
                 " = self.param_map[key].detach().numpy()")
        self.sceneAD.val.configure()
        self.context['boundary_integrator'] = psdr_cpu.BoundaryIntegrator(
            self.sceneAD.val)

    def get_scene(self):
        """Return the primal scene ``sceneAD.val``."""
        return self.sceneAD.val

    def get_optimizer(self, params, props):
        """Build the optimizer named by ``props["optimizer"]``.

        ``"adam"`` forwards ``lr``/``betas``/``eps`` to ``torch.optim.Adam``.
        ``"largestep"`` forwards ``F``/``lr``/``lmbda`` to ``LargeSteps`` (or
        ``LargeStepsGpu`` when ``props["gpu"]`` is truthy); ``F`` names a
        scene attribute that is evaluated and converted to a tensor.

        Raises:
            NotImplementedError: for any other optimizer name.
        """
        name = props["optimizer"]
        if name == "adam":
            adam_props = {
                key: props[key]
                for key in props if key in ["lr", "betas", "eps"]
            }
            return torch.optim.Adam(params, **adam_props)
        if name == "largestep":
            is_gpu = "gpu" in props and props["gpu"]
            ls_props = {
                key: props[key]
                for key in props if key in ["F", "lr", "lmbda"]
            }
            ls_props["F"] = torch.tensor(
                np.array(self.eval("self.sceneAD.val." + ls_props["F"])))
            if is_gpu:
                from pypsdr.optimizer_gpu import LargeStepsGpu
                return LargeStepsGpu(params, **ls_props)
            return LargeSteps(params, **ls_props)
        raise NotImplementedError(f"unsupported optimizer: {name!r}")
