import torch
import vredner
import pyvredner
import copy
import numpy as np


class RenderFunction(torch.autograd.Function):
    """Differentiable vredner render that back-propagates into one shape's vertices.

    Only the vertex tensor of shape ``SHAPE_INDEX`` is differentiable; the
    other forward inputs (scene, derivative scene, options) get ``None``
    gradients.

    Usage follows the ``torch.autograd.Function`` convention::

        image = RenderFunction.apply(vertx, py_scene, py_d_scene, options)
    """

    # Index of the shape whose vertices are driven by ``vertx``.  Hard-coded
    # for now, but kept in one place so forward and backward always agree.
    SHAPE_INDEX = 1

    @staticmethod
    def _alloc_image(py_scene):
        """Return a zeroed float32 tensor of shape (res[1], res[0], 3).

        ``res`` is ``py_scene.camera.resolution``.  The dtype is pinned to
        float32 because the buffer is handed to vredner through
        ``vredner.float_ptr``; it must not follow torch's configurable
        default dtype.
        """
        res = py_scene.camera.resolution
        return torch.zeros(res[1], res[0], 3, dtype=torch.float32)

    @staticmethod
    def forward(ctx, vertx: torch.Tensor,
                py_scene: pyvredner.Scene,
                py_d_scene,
                options):
        """Render ``py_scene`` using ``vertx`` as the vertices of shape SHAPE_INDEX.

        Args:
            ctx: autograd context; stores everything ``backward`` needs.
            vertx: tensor assigned to ``py_scene.shapes[SHAPE_INDEX].vertices``.
            py_scene: scene to render.  Mutated in place: its shape
                ``SHAPE_INDEX`` vertices are replaced by ``vertx``.
            py_d_scene: derivative-scene template.  It is deep-copied and the
                copy is zeroed, so the caller's object is not modified.
            options: render options forwarded unchanged to the integrator.

        Returns:
            A float32 image tensor of shape
            ``(resolution[1], resolution[0], 3)``.
        """
        shape_index = RenderFunction.SHAPE_INDEX

        # Private zeroed copy of the derivative scene; backward reads the
        # vertex gradient out of it (presumably d_render accumulates into it).
        ctx.py_d_scene = copy.deepcopy(py_d_scene)
        ctx.py_d_scene.set_zero()
        ctx.d_scene = ctx.py_d_scene.c_obj()
        ctx.options = options

        py_scene.shapes[shape_index].vertices = vertx
        ctx.py_scene = py_scene
        ctx.c_scene = py_scene.c_obj()

        # TODO: the integrator type is hard-coded.
        ctx.integrator = vredner.DirectADps()

        image = RenderFunction._alloc_image(py_scene)
        ctx.integrator.render(ctx.c_scene, options,
                              vredner.float_ptr(image.data_ptr()))
        return image

    @staticmethod
    def backward(ctx, grad_image):
        """Propagate ``grad_image`` back to the vertices of shape SHAPE_INDEX.

        Returns:
            A 4-tuple matching forward's inputs: the vertex gradient read from
            ``ctx.d_scene.shape_list[SHAPE_INDEX].vertices``, then ``None``
            for ``py_scene``, ``py_d_scene`` and ``options``.
        """
        # The raw pointer below is read as a dense float32 HxWx3 buffer, but
        # autograd may deliver a non-contiguous grad (e.g. the stride-0
        # expanded tensor produced by ``image.sum()``).  Normalise it first.
        grad_image = grad_image.to(dtype=torch.float32).contiguous()

        # Extra output buffer required by d_render; its contents are unused.
        dummy_image = RenderFunction._alloc_image(ctx.py_scene)

        # NOTE(review): the original also built vredner.PrimaryEdgeIntegrator /
        # DirectEdgeIntegrator here, but their d_render calls were commented
        # out.  Only ctx.integrator contributes gradients; re-add them if
        # their contribution is wanted.
        ctx.integrator.d_render(ctx.c_scene, ctx.d_scene, ctx.options,
                                vredner.float_ptr(grad_image.data_ptr()),
                                vredner.float_ptr(dummy_image.data_ptr()))

        grad_vertx = torch.tensor(
            ctx.d_scene.shape_list[RenderFunction.SHAPE_INDEX].vertices)
        return (grad_vertx,
                None,
                None,
                None)
