import unittest
import vredner
import pyvredner
import os
import os.path
import numpy as np
import torch

import os
# Intel OpenMP runtime switch: tolerate the runtime being loaded more than
# once in this process instead of aborting.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Directory that contains this test module.
curr_dir = os.path.dirname(os.path.realpath(__file__))

# Render settings shared by every test. Positional arguments, in order:
# random seed, spp, max bounces, sppe, sppse0, and a trailing boolean flag.
options = vredner.RenderOptions(13, 1024, 1, 0, 0, False)


class TestNormalPreprocess(unittest.TestCase):
    """Smoke tests that drive the vredner renderers on ``plane/scene.xml``.

    Each test renders or differentiates the scene and then hands the
    resulting image to ``pyvredner.imwrite`` and/or prints a statistic of
    the accumulated vertex gradient. None of the tests asserts on the
    numerical result; the outputs are meant to be inspected by hand.
    """

    # Shape used to build / reshape image buffers for scenes created with
    # ``vredner.Scene``.
    # NOTE(review): must agree with the resolution declared in scene.xml.
    IMAGE_SHAPE = (180, 320, 3)

    @classmethod
    def setUpClass(cls):
        """Enter the scene directory and create the shared integrator.

        The tests refer to ``./scene.xml`` and to their output files with
        relative paths, so the working directory is switched to the
        ``plane`` directory next to this module. The previous working
        directory is remembered so ``tearDownClass`` can restore it.
        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        scene_dir = os.path.join(dir_path, 'plane')
        cls._original_cwd = os.getcwd()
        os.chdir(scene_dir)
        cls.scene_file = os.path.join(scene_dir, 'scene.xml')
        cls.integrator = vredner.Direct2()

    @classmethod
    def tearDownClass(cls):
        """Restore the working directory that was active before the tests."""
        os.chdir(cls._original_cwd)

    @staticmethod
    def _render_direct(scene):
        """Render an old-loader *scene* with a fresh ``vredner.Direct``.

        Returns the torch image buffer the integrator rendered into; the
        buffer is sized from ``scene.camera.resolution``.
        """
        integrator = vredner.Direct()
        cam = scene.camera
        c_scene = scene.c_obj()
        image = torch.zeros(cam.resolution[1], cam.resolution[0], 3)
        integrator.render(c_scene, options,
                          vredner.float_ptr(image.data_ptr()))
        return image

    def test_forward_render(self):
        """Compare the new loader and the old loader.

        Renders the scene once through each loader and passes the images
        to ``pyvredner.imwrite`` as ``old_loader.exr`` / ``new_loader.exr``.
        """
        vredner.set_verbose(True)

        # Old loader: render into a preallocated torch buffer.
        scene1, integrator = pyvredner.load_mitsuba("./scene.xml")
        cam = scene1.camera
        height = int(cam.resolution[1])
        width = int(cam.resolution[0])
        c_scene1 = scene1.c_obj()
        img1 = torch.zeros(height, width, 3)
        self.integrator.render(c_scene1, options,
                               vredner.float_ptr(img1.data_ptr()))
        pyvredner.imwrite(img1, "./old_loader.exr")

        # New loader: both loaders read the same file, so the flat result is
        # reshaped with the camera resolution instead of a hard-coded size.
        # NOTE(review): this path renders with the integrator returned by
        # load_mitsuba while the old-loader path uses the class-level
        # Direct2 -- confirm that mixing the two is intended.
        scene = vredner.Scene("./scene.xml")
        image = integrator.renderC(scene, options).reshape(height, width, 3)
        pyvredner.imwrite(torch.from_numpy(image), "./new_loader.exr")

    def test_renderC(self):
        """Render a new-loader scene with ``renderC`` into ``forward.exr``."""
        vredner.set_verbose(True)
        scene = vredner.Scene("./scene.xml")
        image = self.integrator.renderC(scene, options)
        image = image.reshape(self.IMAGE_SHAPE)
        pyvredner.imwrite(torch.from_numpy(image), "./forward.exr")

    def test_finite_difference(self):
        """Produce a finite-difference image to validate gradient images.

        Renders the scene unmodified (``img1.exr``) and again after adding
        0.01 to the second coordinate of vertex 0 of shape 0 (``img2.exr``),
        then writes ``(img2 - img1) / 0.01`` as ``diff.exr``. The comparison
        against a gradient image is left to manual inspection.
        """
        print("finite difference")
        step = 0.01

        scene1, _ = pyvredner.load_mitsuba("./scene.xml")
        img1 = self._render_direct(scene1)
        pyvredner.imwrite(img1, "./img1.exr")

        scene2, _ = pyvredner.load_mitsuba("./scene.xml")
        # Hard-coded perturbation: first vertex of the first shape. It has
        # to happen before the helper builds the C scene object.
        scene2.shapes[0].vertices[0] += torch.tensor([0., step, 0.])
        img2 = self._render_direct(scene2)
        pyvredner.imwrite(img2, "./img2.exr")

        img_diff = (img2 - img1) / step
        pyvredner.imwrite(img_diff, "./diff.exr")

    def test_d_render(self):
        """Run ``Direct.d_render`` on old-loader scenes.

        ``d_image`` is 1 in channel 0 and 0 in channels 1-2; it is passed to
        ``d_render`` together with the zero-initialised ``d_image1`` buffer,
        which is afterwards written as ``deriv.exr``. Finally the norm of
        shape 0's vertices in the derivative scene is printed.
        """
        vredner.set_verbose(True)

        scene, _ = pyvredner.load_mitsuba("./scene.xml")
        c_scene = scene.c_obj()
        cam = scene.camera

        # Second copy of the scene, zeroed before being handed to d_render.
        dscene, _ = pyvredner.load_mitsuba("./scene.xml")
        d_scene = dscene.c_obj()
        d_scene.setZero()

        integrator = vredner.Direct()

        d_image = torch.ones(cam.resolution[1], cam.resolution[0], 3)
        d_image[:, :, 1:3] = 0
        d_image1 = torch.zeros(cam.resolution[1], cam.resolution[0], 3)

        integrator.d_render(c_scene, d_scene, options,
                            vredner.float_ptr(d_image.data_ptr()),
                            vredner.float_ptr(d_image1.data_ptr()))
        pyvredner.imwrite(d_image1, "./deriv.exr")
        grad = np.array(d_scene.shape_list[0].vertices)
        print(np.linalg.norm(grad))

    def test_renderD(self):
        """Run ``renderD`` on new-loader scenes with shape 0 flagged for grad.

        Shape 0 gets ``requires_grad = True``, ``vertex_idx = 8`` and a
        translation of ``[0, 1, 0]`` before rendering. The returned image is
        written as ``grad.exr`` and the sum of shape 0's vertices in the
        derivative scene is printed.
        """
        vredner.set_verbose(True)
        scene = vredner.Scene(self.scene_file)
        d_scene = vredner.Scene(self.scene_file)
        d_scene.setZero()

        # 1 in channel 0, 0 in channels 1-2.
        d_image = np.ones(self.IMAGE_SHAPE)
        d_image[:, :, 1:3] = 0

        shape = scene.shape_list[0]
        shape.requires_grad = True
        shape.vertex_idx = 8
        shape.setTranslation(np.array([0., 1., 0.]))

        img = self.integrator.renderD(scene, d_scene, options,
                                      d_image.reshape((-1, 3)))
        img = img.reshape(self.IMAGE_SHAPE)
        pyvredner.imwrite(torch.from_numpy(img), "./grad.exr")
        print(np.array(d_scene.shape_list[0].vertices).sum())

    def test_normal_preprocess(self):
        """Run ``Direct.d_render1`` and print the vertex-gradient norm.

        Uses the same ``d_image`` pattern as the other gradient tests
        (1 in channel 0, 0 elsewhere) and prints the norm of shape 0's
        vertices in the derivative scene.
        """
        vredner.set_verbose(True)
        scene = vredner.Scene(self.scene_file)
        d_scene = vredner.Scene(self.scene_file)
        d_scene.setZero()
        integrator = vredner.Direct()

        d_image = np.ones(self.IMAGE_SHAPE)
        d_image[:, :, 1:3] = 0

        integrator.d_render1(scene, d_scene, options,
                             d_image.reshape((-1, 3)))
        grad = np.array(d_scene.shape_list[0].vertices)
        print(np.linalg.norm(grad))



if __name__ == '__main__':
    # Manual entry point: prepare the class-level fixtures, then run only
    # the renderD test instead of the whole suite.
    TestNormalPreprocess.setUpClass()
    case = TestNormalPreprocess()
    case.test_renderD()
