from numpy import dtype
import torch
from pypsdr.optimizer import LargeSteps, LargeSteps2

from pypsdr.validate import BoundaryIntegrator, Direct, GuidingOptions, PrimaryEdgeIntegrator, IndirectEdgeIntegrator, DirectEdgeIntegrator, RenderOptions
from .utils.timer import Timer
import psdr_cpu
import numpy as np
from scipy.spatial.transform import Rotation as R

import drjit as dr
import mitsuba as mi
import time

class RenderFunction(torch.autograd.Function):
    """Differentiable psdr_cpu render for optimizing the vertices of a single shape.

    ``context`` is a dict shared with the owning module. Keys read here:
    ``sceneAD``, ``shape_id``, ``integrator``, ``boundary_integrator``,
    ``cameras``, ``options``, ``guiding_options``, ``use_edge_sampling`` and,
    if present, ``mala_options``.
    """

    @staticmethod
    def forward(ctx, V: torch.Tensor, sensor_id, context, params):
        """Write ``V`` into the optimized shape and render from camera ``sensor_id``.

        Args:
            V: vertex positions for ``scene.shapes[context['shape_id']]``.
            sensor_id: index into ``context['cameras']``.
            context: shared render context; updated in place with ``params``.
            params: overrides merged into ``context`` before rendering.

        Returns:
            float32 image tensor reshaped to ``(height, width, 3)``.
        """
        ctx.context = context
        ctx.context.update(params)  # update the default context

        sceneAD = context["sceneAD"]
        shape_id = context['shape_id']
        integrator = context['integrator']
        cameras = context['cameras']
        scene = sceneAD.val
        # update shape
        scene.shapes[shape_id].setVertices(V.detach().numpy())
        scene.shapes[shape_id].configure()  # !

        # switch camera
        ctx.sensor_id = sensor_id
        scene.camera = cameras[sensor_id]
        scene.configure()

        # render
        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def _preprocess_guiding(scene, boundary_integrator, direct_integrator,
                            indirect_integrator, guiding_options, max_bounces):
        """Build the edge-sampling guides requested by ``guiding_options``.

        No-op when ``guide_type`` is ``""``. ``guide_option`` selects the
        "direct" and/or "indirect" part ("both" selects both). Each selected
        part recomputes its edges on the boundary integrator. It is then
        preprocessed with the grid guide when ``guide_type == "grid"``, and
        with the aq guide otherwise.
        """
        if guiding_options.guide_type == "":
            return
        use_grid = guiding_options.guide_type == "grid"
        if guiding_options.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_direct(
                    scene, guiding_options.grid_config_direct, max_bounces)
                direct_integrator.preprocess_grid(
                    scene, guiding_options.grid_config_direct, max_bounces)
            else:
                boundary_integrator.preprocess_aq_direct(
                    scene, guiding_options.aq_config_direct, max_bounces)
                direct_integrator.preprocess_aq(
                    scene, guiding_options.aq_config_direct, max_bounces)
        if guiding_options.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_indirect(
                    scene, guiding_options.grid_config_indirect, max_bounces)
                indirect_integrator.preprocess_grid(
                    scene, guiding_options.grid_config_indirect, max_bounces)
            else:
                boundary_integrator.preprocess_aq_indirect(
                    scene, guiding_options.aq_config_indirect, max_bounces)
                indirect_integrator.preprocess_aq(
                    scene, guiding_options.aq_config_indirect, max_bounces)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Accumulate interior and boundary derivatives into ``sceneAD.der``.

        Returns:
            ``(vertex_gradient, None, None, None)``. NaN entries in the
            vertex gradient are replaced with 0.
        """
        context = ctx.context
        sceneAD = context["sceneAD"]
        sceneAD.zeroGrad()
        scene = sceneAD.val
        shape_id = context['shape_id']
        options = context["options"]
        integrator = context['integrator']
        boundary_integrator = context['boundary_integrator']
        guiding_options = context["guiding_options"]
        # The context may not carry "mala_options" (it depends on how the
        # caller built it); in that case keep the integrator's current MALA setup.
        mala_options = context.get("mala_options")
        use_edge_sampling = context["use_edge_sampling"]
        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix; grad_out has
        # the float32 dtype of forward's output.
        grad_flat = grad_out.reshape(-1).numpy()

        # switch camera
        scene.camera = context["cameras"][ctx.sensor_id]
        # NOTE: need to reconfigure the primary edge integrator when switching camera
        if mala_options is not None:
            boundary_integrator.configure_mala(mala_options)
        boundary_integrator.configure_primary(scene)
        # only the optimized shape has edges enabled
        for i in range(len(scene.shapes)):
            scene.shapes[i].enable_edge = False
        scene.shapes[shape_id].enable_edge = True
        boundary_integrator.configure(scene)
        # TODO change to d_render
        integrator.renderD(sceneAD, options, grad_flat)

        # boundary integrator need to be reconstructed every time
        boundary_integrator.configure_primary(scene)

        direct_integrator = psdr_cpu.DirectEdgeIntegrator(scene)
        indirect_integrator = psdr_cpu.IndirectEdgeIntegrator(scene)
        primary_integrator = psdr_cpu.PrimaryEdgeIntegrator(scene)

        RenderFunction._preprocess_guiding(
            scene, boundary_integrator, direct_integrator, indirect_integrator,
            guiding_options, options.max_bounces)

        if use_edge_sampling is False:
            boundary_integrator.renderD(sceneAD, options, grad_flat)
        else:
            for edge_integrator in (primary_integrator, direct_integrator,
                                    indirect_integrator):
                edge_integrator.renderD(sceneAD, options, grad_flat)

        grad_vertx = torch.tensor(sceneAD.der.shapes[shape_id].vertices)
        grad_vertx[grad_vertx.isnan()] = 0.
        print("backward")
        return (grad_vertx,
                None,
                None,
                None)


class Render(torch.nn.Module):
    """Module wrapper that renders a single shape through ``RenderFunction``.

    Args:
        sceneAD: psdr_cpu differentiable scene.
        cameras: cameras indexed by ``sensor_id`` in ``forward``.
        integrator: interior integrator (``renderC`` / ``renderD``).
        options: render options passed to the integrators.
        shape_id: index of the shape whose vertices are optimized.
        guiding_options: guiding configuration; ``None`` means
            ``GuidingOptions("")``.
        mala_options: MALA configuration; ``None`` means
            ``psdr_cpu.MALAOptions()``.
        use_edge_sampling: selects the edge-integrator branch in backward.
    """

    def __init__(self, sceneAD, cameras, integrator, options, shape_id,
                 guiding_options=None, mala_options=None, use_edge_sampling=False):
        super(Render, self).__init__()
        # Build defaults per instance rather than sharing one object that is
        # created once when the class is defined.
        if guiding_options is None:
            guiding_options = GuidingOptions("")
        if mala_options is None:
            mala_options = psdr_cpu.MALAOptions()
        b_int = psdr_cpu.BoundaryIntegrator(sceneAD.val)
        b_int.configure_mala(mala_options)
        self.context = dict(
            sceneAD=sceneAD,
            cameras=cameras,
            integrator=integrator,
            boundary_integrator=b_int,
            options=options,
            guiding_options=guiding_options,
            # read by RenderFunction.backward
            mala_options=mala_options,
            shape_id=shape_id,
            use_edge_sampling=use_edge_sampling
        )
        print(shape_id)

    def setState(self, state):
        """Merge ``state`` into the context.

        If ``state`` replaces ``sceneAD``, the boundary integrator is rebuilt
        for the new scene. It gets the same MALA configuration as in
        ``__init__``.
        """
        self.context.update(state)
        if "sceneAD" in state:
            sceneAD = self.context["sceneAD"]
            b_int = psdr_cpu.BoundaryIntegrator(sceneAD.val)
            mala_options = self.context.get("mala_options")
            if mala_options is not None:
                b_int.configure_mala(mala_options)
            self.context['boundary_integrator'] = b_int

    def forward(self, V: torch.Tensor, sensor_id, params=None):
        """Render vertices ``V`` from camera ``sensor_id``.

        ``params`` (optional dict) is merged into the shared context first.
        """
        if params is None:
            params = {}
        return RenderFunction.apply(
            V, sensor_id, self.context, params)


class RenderFunction2(torch.autograd.Function):
    """Differentiable psdr_cpu render of a single shape under a per-view rotation.

    The canonical vertices ``V`` are rotated by ``context["obj_pos"][sensor_id]``
    before rendering. In ``backward`` the vertex gradient is rotated back.
    The camera is not switched; ``sensor_id`` only selects the pose.
    Each pose entry ``pos`` is turned into the rotation vector
    ``pos[3] * pos[0:3]``; presumably ``pos[0:3]`` is a unit axis and
    ``pos[3]`` an angle in radians (TODO confirm with callers).
    """

    @staticmethod
    def forward(ctx, V: torch.Tensor, sensor_id, context, params):
        """Rotate ``V`` into pose ``sensor_id``, write it into the shape and render.

        Args:
            V: canonical (unrotated) vertex positions.
            sensor_id: index into ``context["obj_pos"]``.
            context: shared render context; updated in place with ``params``.
            params: overrides merged into ``context`` before rendering.

        Returns:
            float32 image tensor reshaped to ``(height, width, 3)``.
        """
        ctx.context = context
        ctx.context.update(params)  # update the default context

        sceneAD = context["sceneAD"]
        shape_id = context['shape_id']
        integrator = context['integrator']
        scene = sceneAD.val
        ctx.sensor_id = sensor_id

        raw_position = V.detach().numpy()
        pos = context["obj_pos"][sensor_id]
        r = R.from_rotvec(pos[3] * pos[0:3])
        # row-vector form of v' = R v
        rot_position = np.dot(raw_position, r.as_matrix().transpose())
        scene.shapes[shape_id].setVertices(rot_position)

        scene.shapes[shape_id].configure()  # !
        scene.configure()

        # render
        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def _preprocess_guiding(scene, boundary_integrator, direct_integrator,
                            indirect_integrator, guiding_options, max_bounces):
        """Build the edge-sampling guides requested by ``guiding_options``.

        No-op when ``guide_type`` is ``""``. ``guide_option`` selects the
        "direct" and/or "indirect" part ("both" selects both). Each selected
        part recomputes its edges on the boundary integrator. It is then
        preprocessed with the grid guide when ``guide_type == "grid"``, and
        with the aq guide otherwise.
        """
        if guiding_options.guide_type == "":
            return
        use_grid = guiding_options.guide_type == "grid"
        if guiding_options.guide_option in ("direct", "both"):
            boundary_integrator.recompute_direct_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_direct(
                    scene, guiding_options.grid_config_direct, max_bounces)
                direct_integrator.preprocess_grid(
                    scene, guiding_options.grid_config_direct, max_bounces)
            else:
                boundary_integrator.preprocess_aq_direct(
                    scene, guiding_options.aq_config_direct, max_bounces)
                direct_integrator.preprocess_aq(
                    scene, guiding_options.aq_config_direct, max_bounces)
        if guiding_options.guide_option in ("indirect", "both"):
            boundary_integrator.recompute_indirect_edge(scene)
            if use_grid:
                boundary_integrator.preprocess_grid_indirect(
                    scene, guiding_options.grid_config_indirect, max_bounces)
                indirect_integrator.preprocess_grid(
                    scene, guiding_options.grid_config_indirect, max_bounces)
            else:
                boundary_integrator.preprocess_aq_indirect(
                    scene, guiding_options.aq_config_indirect, max_bounces)
                indirect_integrator.preprocess_aq(
                    scene, guiding_options.aq_config_indirect, max_bounces)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Compute the gradient w.r.t. the canonical vertices.

        The derivative w.r.t. the rotated vertices is read from ``sceneAD.der``
        and rotated back by the inverse pose. NaN entries are replaced with 0.
        """
        context = ctx.context
        sceneAD = context["sceneAD"]
        sceneAD.zeroGrad()
        scene = sceneAD.val
        shape_id = context['shape_id']
        options = context["options"]
        integrator = context['integrator']
        boundary_integrator = context['boundary_integrator']
        guiding_options = context["guiding_options"]
        # The context may not carry "mala_options" (it depends on how the
        # caller built it); in that case keep the integrator's current MALA setup.
        mala_options = context.get("mala_options")
        use_edge_sampling = context["use_edge_sampling"]
        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix
        grad_flat = grad_out.reshape(-1).numpy()

        # TODO change to d_render
        integrator.renderD(sceneAD, options, grad_flat)
        # boundary integrator need to be reconstructed every time
        boundary_integrator.configure_primary(scene)
        if mala_options is not None:
            boundary_integrator.configure_mala(mala_options)

        direct_integrator = psdr_cpu.DirectEdgeIntegrator(scene)
        indirect_integrator = psdr_cpu.IndirectEdgeIntegrator(scene)
        primary_integrator = psdr_cpu.PrimaryEdgeIntegrator(scene)

        RenderFunction2._preprocess_guiding(
            scene, boundary_integrator, direct_integrator, indirect_integrator,
            guiding_options, options.max_bounces)

        if use_edge_sampling is False:
            boundary_integrator.renderD(sceneAD, options, grad_flat)
        else:
            for edge_integrator in (primary_integrator, direct_integrator,
                                    indirect_integrator):
                edge_integrator.renderD(sceneAD, options, grad_flat)

        # forward maps v -> R v, so dL/dv = R^T dL/dv'; the inverse rotation
        # R^T is built from the negated rotation vector.
        raw_grad = sceneAD.der.shapes[shape_id].vertices
        pos = context["obj_pos"][ctx.sensor_id]
        r_inv = R.from_rotvec((-pos[3]) * pos[0:3])
        rot_grad = np.dot(raw_grad, r_inv.as_matrix().transpose())

        grad_vertx = torch.tensor(rot_grad)
        grad_vertx[grad_vertx.isnan()] = 0.
        return (grad_vertx,
                None,
                None,
                None)


class Render2(torch.nn.Module):
    """Module wrapper that renders a posed shape through ``RenderFunction2``.

    Args:
        sceneAD: psdr_cpu differentiable scene.
        cameras: stored in the context.
        integrator: interior integrator (``renderC`` / ``renderD``).
        options: render options passed to the integrators.
        guiding_options: guiding configuration used in backward.
        shape_id: index of the optimized shape.
        obj_pos: per-pose entries indexed by ``sensor_id`` in ``forward``.
        mala_options: MALA configuration; ``None`` means
            ``psdr_cpu.MALAOptions()``.
        use_edge_sampling: selects the edge-integrator branch in backward.
    """

    def __init__(self, sceneAD, cameras, integrator, options, guiding_options, shape_id, obj_pos,
                 mala_options=None, use_edge_sampling=False):
        super(Render2, self).__init__()
        # Build the default per instance rather than sharing one object that
        # is created once when the class is defined.
        if mala_options is None:
            mala_options = psdr_cpu.MALAOptions()
        b_int = psdr_cpu.BoundaryIntegrator(sceneAD.val)
        b_int.configure_mala(mala_options)
        self.context = dict(
            sceneAD=sceneAD,
            cameras=cameras,
            integrator=integrator,
            boundary_integrator=b_int,
            options=options,
            guiding_options=guiding_options,
            # read by RenderFunction2.backward
            mala_options=mala_options,
            shape_id=shape_id,
            obj_pos=obj_pos,
            use_edge_sampling=use_edge_sampling
        )

    def setState(self, state):
        """Merge ``state`` into the context.

        If ``state`` replaces ``sceneAD``, the boundary integrator is rebuilt
        for the new scene. It gets the same MALA configuration as in
        ``__init__``.
        """
        self.context.update(state)
        if "sceneAD" in state:
            sceneAD = self.context["sceneAD"]
            b_int = psdr_cpu.BoundaryIntegrator(sceneAD.val)
            mala_options = self.context.get("mala_options")
            if mala_options is not None:
                b_int.configure_mala(mala_options)
            self.context['boundary_integrator'] = b_int

    def forward(self, V: torch.Tensor, sensor_id, params=None):
        """Render canonical vertices ``V`` in pose ``sensor_id``.

        ``params`` (optional dict) is merged into the shared context first.
        """
        if params is None:
            params = {}
        return RenderFunction2.apply(
            V, sensor_id, self.context, params)


class RenderFunction3(torch.autograd.Function):
    """Differentiable psdr_cpu render w.r.t. the reflectance texture of ``scene.bsdfs[0]``.

    ``context`` keys read: ``sceneAD``, ``integrator``, ``options``, and the
    optional ``texture_res``. ``texture_res`` is the resolution passed to
    ``psdr_cpu.Bitmap`` and defaults to ``(225, 225)``.
    """

    @staticmethod
    def forward(ctx, texture: torch.Tensor, context):
        """Upload ``texture`` as the reflectance bitmap of ``bsdfs[0]`` and render.

        Returns:
            float32 image tensor reshaped to ``(height, width, 3)``.
        """
        ctx.context = context
        # remember the input layout so backward returns a gradient of that shape
        ctx.texture_shape = texture.shape
        sceneAD = context["sceneAD"]
        integrator = context['integrator']
        scene = sceneAD.val
        # update texture
        bitmap = psdr_cpu.Bitmap(
            texture.reshape(-1).detach().numpy(),
            context.get("texture_res", (225, 225)))
        scene.bsdfs[0].reflectance = bitmap
        scene.configure()
        # render
        width, height = scene.camera.width, scene.camera.height
        image = integrator.renderC(scene, context["options"])\
            .reshape(height, width, 3)
        return torch.tensor(image, dtype=torch.float)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return the texture gradient (NaNs replaced by 0), shaped like the input."""
        context = ctx.context
        sceneAD = context["sceneAD"]
        sceneAD.zeroGrad()
        options = context["options"]
        integrator = context['integrator']
        # NOTE can't bind torch.tensor(float64) to Eigen::Matrix
        integrator.renderD(sceneAD,
                           options, grad_out.reshape(-1).numpy())
        texel_grad = np.array(sceneAD.der.bsdfs[0].reflectance.m_data)
        # NOTE(review): assumes m_data keeps the element order of the
        # flattened texture handed to Bitmap in forward — TODO confirm.
        grad_texture = torch.tensor(texel_grad).reshape(ctx.texture_shape)
        grad_texture[grad_texture.isnan()] = 0.
        return (grad_texture,
                None)


class Render3(torch.nn.Module):
    """Module wrapper that renders a texture through ``RenderFunction3``.

    Args:
        sceneAD: psdr_cpu differentiable scene.
        integrator: integrator providing ``renderC`` / ``renderD``.
        options: render options passed to the integrator.
    """

    def __init__(self, sceneAD, integrator, options):
        super().__init__()
        self.context = {
            "sceneAD": sceneAD,
            "integrator": integrator,
            "options": options,
        }

    def setState(self, state):
        """Merge the entries of ``state`` into the render context."""
        self.context.update(state)

    def forward(self, texture: torch.Tensor, params):
        """Merge ``params`` into the context, then render ``texture``."""
        render_ctx = self.context
        render_ctx.update(params)
        return RenderFunction3.apply(texture, render_ctx)


class JointRenderFunction(torch.autograd.Function):
    """Differentiable psdr_cpu render used by ``Model`` for joint optimization.

    ``x`` is the flattened concatenation of the optimized parameters. ``forward``
    does not read it and renders ``context["sceneAD"].val`` as-is, so the caller
    is expected to have written the parameter values into the scene. ``backward``
    returns the concatenated gradients for ``x``, in ``context["param_map"]``
    order.
    """

    @staticmethod
    def forward(ctx, x, context):
        ctx.context = context
        with Timer("forward pass"):
            scene = context["sceneAD"].val
            rendered = context["integrator"].renderC(scene, context["options"])
            image_shape = (scene.camera.height, scene.camera.width, 3)
            return torch.tensor(rendered, dtype=torch.float).reshape(*image_shape)

    @staticmethod
    def backward(ctx, img_grad):
        context = ctx.context
        sceneAD = context["sceneAD"]
        options = context["options"]
        sceneAD.zeroGrad()
        with Timer("backward pass"):
            context["integrator"].renderD(
                sceneAD, options,
                img_grad.reshape(-1).numpy().astype(np.float32))

        # boundary contribution; the primary-edge setup depends on the camera
        boundary = context["boundary_integrator"]
        boundary.configure_primary(sceneAD.val)
        boundary.renderD(sceneAD, options,
                         img_grad.reshape(-1).numpy().astype(np.float32))

        # NOTE: eval must run directly in this frame (not in a comprehension)
        # so that the local name ``sceneAD`` is visible to the evaluated string.
        param_map = {}
        for key in context["param_map"]:
            grad_value = eval("sceneAD.der." + key)
            param_map[key] = torch.tensor(np.array(grad_value), dtype=torch.float32)
        param_grad = torch.cat([grad.reshape(-1) for grad in param_map.values()])

        for key, grad in param_map.items():
            print("grad ", key, grad.sum())

        return torch.nan_to_num(param_grad), None


class Model:
    """Joint optimization of scene parameters addressed by attribute paths.

    Args:
        sceneAD: psdr_cpu differentiable scene; parameters are read from and
            written to ``sceneAD.val``.
        param_ids: mapping ``path -> optimizer props``. ``path`` is an attribute
            path under ``sceneAD.val`` (e.g. ``"shapes[0].vertices"``). The props
            contain ``"optimizer"`` plus optimizer-specific keys (see
            ``get_optimizer``).
        context: overrides merged into the default render context.

    NOTE: paths are evaluated with ``eval``/``exec``; they must come from trusted
    configuration, never from external input.
    """

    def __init__(self,
                 sceneAD,
                 param_ids,  # parameter indices in the Scene
                 context=None):
        self.sceneAD = sceneAD
        # parameters to be optimized, initialized from the current scene values
        self.param_map = {
            key: torch.tensor(np.array(self.eval("self.sceneAD.val." + key)),
                              dtype=torch.float32,
                              requires_grad=True)
            for key in param_ids
        }
        # one optimizer per parameter, configured by that parameter's props
        self.optimizers = [
            self.get_optimizer(params=[self.param_map[key]],
                               props=param_ids[key])
            for key in param_ids
        ]
        self.default_context = {
            "sceneAD": sceneAD,
            "integrator": Direct(),
            "boundary_integrator": psdr_cpu.BoundaryIntegrator(sceneAD.val),
            "options": RenderOptions(seed=1,
                                     max_bounce=1,
                                     spp=128,
                                     sppe=0,
                                     sppse0=0,
                                     sppse1=0),
            "param_map": self.param_map,  # parameters to be optimized
            "cameras": [sceneAD.val.camera],
            "sensor_id": 0,
            # no multi-pose transformation unless the caller supplies one
            "pose_setting": None,
        }
        # NOTE: self.context aliases default_context, so updates affect both
        self.context = self.default_context
        if context:
            self.context.update(context)

    def render(self, context=None):
        """Render with the current parameters, applying the optional pose.

        ``context`` entries are merged into ``self.context`` (persistently).
        An optional ``pose_setting`` entry has the form::

            {"id": "shapes[0].vertices",   # key in param_map, must be vertices
             "poses": [np.ndarray, ...],  # 3x3 rotations or 4x4 transforms
             "pose_id": 0}

        Returns:
            The image tensor produced by ``JointRenderFunction``.

        Raises:
            ValueError: if the selected pose matrix is neither 3x3 nor 4x4.
        """
        ctx = self.context
        if context:
            ctx.update(context)
        # multi-view
        ctx["sceneAD"].val.camera = ctx["cameras"][ctx["sensor_id"]]
        # the parameters that will be updated in the scene
        render_param_map = dict(ctx["param_map"])
        # multi-pose: transform the vertices here and update the param map
        pose_setting = ctx.get("pose_setting")
        if pose_setting:
            key = pose_setting["id"]
            assert key in ctx["param_map"]
            assert key.endswith("vertices")  # make sure it's a shape
            transform = pose_setting["poses"][pose_setting["pose_id"]]
            # stays in the computation graph of the original vertices
            render_param_map[key] = self._apply_pose(ctx["param_map"][key],
                                                     transform)
            # update scene
            self.update_scene(render_param_map)

        x = torch.cat([render_param_map[key].reshape(-1)
                       for key in render_param_map])
        return JointRenderFunction.apply(x, ctx)

    @staticmethod
    def _apply_pose(V, transform):
        """Return ``V`` transformed by a 3x3 rotation or a 4x4 homogeneous matrix.

        The product is computed in float64 and returned as float32.
        """
        if transform.shape == (3, 3):
            with Timer("transform vertices"):
                V = V.double() @ torch.tensor(transform, dtype=torch.float64).T
                return V.float()
        if transform.shape == (4, 4):
            V = V.double()
            V = torch.hstack([V, torch.ones((V.shape[0], 1), dtype=torch.float64)])
            V = V @ torch.tensor(transform, dtype=torch.float64).T
            return (V[:, :3] / V[:, 3:]).float()
        raise ValueError(
            f"unsupported pose matrix shape {tuple(transform.shape)}; "
            "expected (3, 3) or (4, 4)")

    def eval(self, s):
        """Evaluate ``s`` with ``self`` in scope (used for attribute paths)."""
        return eval(s)

    def step(self):
        """Step every optimizer, then write the parameters back into the scene."""
        for optimizer in self.optimizers:
            optimizer.step()
        self.update_scene(self.param_map)

    def zero_grad(self):
        """Reset the gradients of all optimizers."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def update_scene(self, param_map):
        """Write ``param_map`` values into ``sceneAD.val`` and rebuild derived state.

        The scene is reconfigured and a new boundary integrator is stored in
        ``self.context``.
        """
        for key in param_map:
            exec("self.sceneAD.val." + key +
                 " = param_map[key].detach().numpy()")
        # update relevant scene data structure
        self.sceneAD.val.configure()
        # update boundary integrator
        self.context['boundary_integrator'] = psdr_cpu.BoundaryIntegrator(
            self.sceneAD.val)

    def get_scene(self):
        """Return the primal psdr_cpu scene."""
        return self.sceneAD.val

    def get_optimizer(self, params, props):
        """Create the optimizer named by ``props["optimizer"]``.

        Supported names:
            "adam": ``torch.optim.Adam``, forwarding ``lr``, ``betas`` and ``eps``.
            "largestep": ``LargeSteps`` (``LargeStepsGpu`` if ``props["gpu"]``
                is truthy), forwarding ``F``, ``lr`` and ``lmbda``.
            "largestep2": ``LargeSteps2``, forwarding ``F``, ``lr`` and ``lmbda``.
        For the large-step variants ``props["F"]`` is an attribute path under
        ``sceneAD.val`` (presumably the face indices); it is replaced by a
        tensor of its value.

        Raises:
            NotImplementedError: for any other optimizer name.
        """
        name = props["optimizer"]
        if name == "adam":
            adam_props = {key: value for key, value in props.items()
                          if key in ("lr", "betas", "eps")}
            return torch.optim.Adam(params, **adam_props)
        if name in ("largestep", "largestep2"):
            ls_props = {key: value for key, value in props.items()
                        if key in ("F", "lr", "lmbda")}
            ls_props["F"] = torch.tensor(
                np.array(self.eval("self.sceneAD.val." + ls_props["F"])))
            if name == "largestep2":
                return LargeSteps2(params, **ls_props)
            if props.get("gpu"):
                from pypsdr.optimizer_gpu import LargeStepsGpu
                return LargeStepsGpu(params, **ls_props)
            return LargeSteps(params, **ls_props)
        raise NotImplementedError(f"unknown optimizer: {name}")

class RenderMi(torch.nn.Module):
    """Module wrapper that renders a posed shape with Mitsuba via ``RenderFunctionMi``.

    Args:
        integrator: Mitsuba integrator used for rendering and its backward pass.
        obj_pos: per-pose entries indexed by ``pos_id`` in ``forward``.
    """

    def __init__(self, integrator, obj_pos):
        super().__init__()
        self.integrator = integrator
        self.obj_pos = obj_pos

    def forward(self, scene, params, key, seed, spp, V, pos_id, fwd_img=None):
        """Render vertices ``V`` in pose ``pos_id``.

        If ``fwd_img`` is given, it is returned as the forward result instead
        of rendering.
        """
        render_ctx = {
            "scene": scene,
            "params": params,
            "key": key,
            "integrator": self.integrator,
            "seed": seed,
            "spp": spp,
            "obj_pos": self.obj_pos,
            "pos_id": pos_id,
        }
        return RenderFunctionMi.apply(render_ctx, V, fwd_img)

class RenderFunctionMi(torch.autograd.Function):
    """Differentiable Mitsuba render of a single shape under a per-pose rotation.

    Each pose entry ``pos = context["obj_pos"][context["pos_id"]]`` is turned
    into the rotation vector ``pos[3] * pos[0:3]``. Presumably this is a unit
    axis times an angle in radians (TODO confirm with callers).
    """

    @staticmethod
    def _pose_rotation(context):
        """Return the scipy ``Rotation`` for pose ``context["pos_id"]``."""
        pos = context["obj_pos"][context["pos_id"]]
        return R.from_rotvec(pos[3] * pos[0:3])

    @staticmethod
    def forward(ctx, context, V, fwd_img=None):
        """Rotate ``V`` into the selected pose and render it with Mitsuba.

        If ``fwd_img`` is not None, it is returned unchanged and no forward
        render is performed. The rotated positions are still cached for
        ``backward``.
        """
        scene = context["scene"]
        key = context["key"]
        params = context["params"]
        integrator = context["integrator"]
        seed = context["seed"]
        spp = context["spp"]
        ctx.context = context
        ctx.V = V
        rotation = RenderFunctionMi._pose_rotation(context)
        # row-vector form of v' = R v
        rot_position = np.dot(V.detach().numpy(), rotation.as_matrix().transpose())
        # cached so backward does not recompute the pose
        ctx.rotation = rotation
        ctx.rot_position = rot_position
        ctx.fwd_img = fwd_img
        ctx.V_shape = rot_position.shape
        if fwd_img is not None:
            return fwd_img

        V_mi = dr.llvm.ad.Float(rot_position.flatten())
        params[key] = V_mi
        params.update()
        time0 = time.time()
        with dr.suspend_grad():
            img = np.array(mi.render(scene, params, integrator=integrator, spp=spp, seed=seed))
        time1 = time.time()
        print("forward time: ", time1 - time0)
        return torch.tensor(img)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Back-propagate ``grad_out`` through Mitsuba to the canonical vertices.

        Returns:
            ``(None, vertex_gradient, None)``. NaN entries are replaced with 0.
        """
        context = ctx.context
        scene = context["scene"]
        key = context["key"]
        params = context["params"]
        integrator = context["integrator"]

        V_mi = dr.llvm.ad.Float(ctx.rot_position.flatten())
        dr.enable_grad(V_mi)
        params[key] = V_mi
        params.update()
        time0 = time.time()
        # Convert to a contiguous float32 tensor before handing it to DrJit.
        # The float32 cast matches RenderFunctionMi2; contiguous() guards
        # against expanded/strided gradients from autograd.
        grad_mi = dr.llvm.ad.TensorXf(grad_out.to(torch.float32).contiguous())
        integrator.render_backward(scene, params, grad_mi)
        time1 = time.time()
        print("backward time: ", time1 - time0)
        print(" ")
        raw_grad = np.array(dr.grad(V_mi)).reshape(ctx.V_shape)
        # forward maps v -> R v, so dL/dv = R^T dL/dv' (row form: g @ R)
        rot_grad = np.dot(raw_grad, ctx.rotation.as_matrix())

        grad_vertx = torch.tensor(rot_grad)
        grad_vertx[grad_vertx.isnan()] = 0.
        return (None,
                grad_vertx,
                None)
        
class RenderMi2(torch.nn.Module):
    """Module wrapper that renders through ``RenderFunctionMi2``.

    Args:
        integrator_b: boundary integrator from Mitsuba.
        integrator_i: interior integrator from psdr_cpu.
        cameras: Mitsuba sensors indexed by ``cam_id``.
        cameras_psdr: psdr_cpu cameras indexed by ``cam_id``.
        max_bounces: stored in the render context.
        shape_id: stored in the render context.
    """

    def __init__(self, integrator_b, integrator_i, cameras, cameras_psdr, max_bounces, shape_id):
        super().__init__()
        self.integrator_b = integrator_b  # boundary integrator from mitsuba
        self.integrator_i = integrator_i  # interior integrator from psdr
        self.cameras = cameras
        self.cameras_psdr = cameras_psdr
        self.max_bounces = max_bounces
        self.shape_id = shape_id

    def forward(self, scene, sceneAD, params, key, seed, spp, V, cam_id):
        """Render vertices ``V`` from camera ``cam_id``.

        ``scene`` is the Mitsuba scene and ``sceneAD`` the psdr_cpu scene.
        """
        render_ctx = {
            "scene": scene,
            "sceneAD": sceneAD,
            "params": params,
            "key": key,
            "integrator_b": self.integrator_b,
            "integrator_i": self.integrator_i,
            "seed": seed,
            "spp": spp,
            "cameras": self.cameras,
            "cameras_psdr": self.cameras_psdr,
            "cam_id": cam_id,
            "max_bounces": self.max_bounces,
            "shape_id": self.shape_id,
        }
        return RenderFunctionMi2.apply(render_ctx, V)

class RenderFunctionMi2(torch.autograd.Function):
    """Differentiable Mitsuba render of a single shape's vertices from camera ``cam_id``.

    Only the Mitsuba integrator ``integrator_b`` contributes to the image and to
    the gradient. A psdr_cpu interior pass (``integrator_i`` on ``sceneAD``) is
    not performed. ``backward`` still switches the psdr_cpu scene to the
    matching camera and reconfigures it.
    """

    @staticmethod
    def forward(ctx, context, V):
        """Write ``V`` into ``params[key]`` and render with ``cameras[cam_id]``.

        Returns:
            The rendered image as a tensor.
        """
        scene = context["scene"]
        key = context["key"]
        params = context["params"]
        cam_id = context["cam_id"]
        ctx.context = context
        ctx.V = V
        ctx.V_shape = V.shape
        print("cam id: ", cam_id)

        V_mi = dr.llvm.ad.Float(V.flatten())
        params[key] = V_mi
        params.update()
        image = np.array(mi.render(scene, params,
                                   integrator=context["integrator_b"],
                                   sensor=context["cameras"][cam_id],
                                   spp=context["spp"],
                                   seed=context["seed"]))
        return torch.tensor(image)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Back-propagate ``grad_out`` through Mitsuba to the vertices.

        Returns:
            ``(None, vertex_gradient)``, one entry per forward input. NaN
            entries are replaced with 0.
        """
        context = ctx.context
        sceneAD = context["sceneAD"]
        key = context["key"]
        params = context["params"]
        cam_id = context["cam_id"]

        # keep the psdr_cpu scene on the same camera
        sceneAD.val.camera = context["cameras_psdr"][cam_id]
        sceneAD.val.configure()

        V_mi = dr.llvm.ad.Float(ctx.V.flatten())
        dr.enable_grad(V_mi)
        params[key] = V_mi
        params.update()
        # Convert to a contiguous float32 tensor before handing it to DrJit;
        # contiguous() guards against expanded/strided gradients from autograd.
        grad_mi = dr.llvm.ad.TensorXf(grad_out.to(torch.float32).contiguous())
        context["integrator_b"].render_backward(
            context["scene"], params, grad_mi, sensor=context["cameras"][cam_id])
        grad_V = np.array(dr.grad(V_mi)).reshape(ctx.V_shape)

        grad_vertx = torch.tensor(grad_V)
        grad_vertx[grad_vertx.isnan()] = 0.
        return (None,
                grad_vertx)