import vredner
import pyvredner
import torch

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# --- Scene, render options and integrators ---------------------------------
# The same Mitsuba scene file is loaded twice: once as the scene that gets
# rendered, and once as a second scene object that is handed to the
# d_render() calls as their second argument.
scene_path = "./scene_side.xml"

# load_mitsuba() also returns an integrator; it is discarded on purpose
# because the integrators used by this script are constructed explicitly
# further down.
scene, _ = pyvredner.load_mitsuba(scene_path)
cam = scene.camera
c_scene = scene.c_obj()
# NOTE: hard-coded toggle, left disabled in this script. The shape index 1 is
# specific to this scene file.
# c_scene.shape_list[1].requires_grad = True

options = vredner.RenderOptions(
    13,     # random seed
    1024,   # spp
    1,      # max bounces
    128,    # sppse0
    128,    # sppe
    False,  # NOTE(review): meaning not visible here (presumably a quiet/verbose flag) - confirm
)

# Keep `dscene` bound to a name: `d_scene` is obtained from it via c_obj().
dscene, _ = pyvredner.load_mitsuba(scene_path)
d_scene = dscene.c_obj()
# Same hard-coded toggle as above, for the second scene (also disabled).
# d_scene.shape_list[1].requires_grad = True

# `integrator` is used for both render() and d_render(); `integrator1` and
# `integrator2` are only used for d_render(). Both edge integrators are
# constructed from the rendered scene (c_scene).
integrator = vredner.Direct()
integrator1 = vredner.PrimaryEdgeIntegrator(c_scene)
integrator2 = vredner.DirectEdgeIntegrator(c_scene)

# --- Forward render --------------------------------------------------------
# Allocate a zero-filled buffer with resolution[1] rows, resolution[0]
# columns and 3 channels, let the integrator fill it through a raw float
# pointer, then write the result to disk.
n_rows = cam.resolution[1]
n_cols = cam.resolution[0]
rendered_image = torch.zeros(n_rows, n_cols, 3)

# The tensor stays bound to `rendered_image`, so the memory behind the raw
# pointer remains alive for the duration of the render() call.
image_ptr = vredner.float_ptr(rendered_image.data_ptr())
integrator.render(c_scene, options, image_ptr)

pyvredner.imwrite(rendered_image, "./img1.exr")

# --- Derivative render -----------------------------------------------------
n_rows = cam.resolution[1]
n_cols = cam.resolution[0]

# First buffer passed to d_render(): 1 in channel 0, 0 in channels 1 and 2.
d_image = torch.zeros(n_rows, n_cols, 3)
d_image[:, :, 0] = 1

# Second buffer passed to d_render(): starts out all zero and is what gets
# written to disk at the end.
# NOTE(review): presumably d_image is the input weighting and d_image1 the
# output each integrator writes/accumulates into - confirm against the
# vredner d_render() binding.
d_image1 = torch.zeros(n_rows, n_cols, 3)

# Run d_render() on each integrator in turn, always with the same scene pair,
# options and buffers. The order matches the original script.
for diff_integrator in (integrator, integrator1, integrator2):
    diff_integrator.d_render(c_scene, d_scene, options,
                             vredner.float_ptr(d_image.data_ptr()),
                             vredner.float_ptr(d_image1.data_ptr()))

pyvredner.imwrite(d_image1, "./out.exr")