import enum
from typing import List
from pypsdr.utils.io import imwrite
import psdr_cpu
from psdr_cpu import Path2
from pypsdr.common import gen_camera, gen_cameras, sample_on_sphere
import os
import numpy as np
import psdr_cpu
import gin
from scipy.spatial.transform import Rotation as R

# NOTE(review): KMP_DUPLICATE_LIB_OK is an OpenMP-runtime switch; presumably set
# here because more than one loaded library ships its own OpenMP runtime, which
# would otherwise abort the process — confirm on the target platform.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')


def render_scene(scene: psdr_cpu.Scene, options, integrator=None) -> np.ndarray:
    """Render ``scene`` through its current camera and return the image.

    Args:
        scene: Scene to render; ``scene.camera.width``/``height`` fix the
            output size.
        options: Render options passed unchanged to ``integrator.renderC``.
        integrator: Object exposing ``renderC(scene, options)``. When omitted
            (``None``) a fresh ``Path2()`` is created for this call.

    Returns:
        The integrator output reshaped to ``(height, width, 3)``.
    """
    if integrator is None:
        # Build the default per call: a ``Path2()`` default argument would be
        # constructed once at import time and shared by every caller.
        integrator = Path2()
    width, height = scene.camera.width, scene.camera.height
    return integrator.renderC(scene, options).reshape((height, width, 3))


def render_scene_file(scene_file, options, integrator=None) -> np.ndarray:
    """Load the scene stored in ``scene_file`` and render it.

    While the scene is loaded and rendered, the working directory is switched
    to the folder containing ``scene_file``. This is presumably so that
    relative asset paths inside the scene file resolve (TODO confirm against
    psdr_cpu). The previous working directory is restored afterwards, even if
    loading or rendering raises.

    Args:
        scene_file: Path (absolute or relative to the current directory) of
            the scene description file.
        options: Render options passed to the integrator.
        integrator: Object exposing ``renderC(scene, options)``; defaults to a
            fresh ``Path2()`` when omitted.

    Returns:
        The rendered image with shape ``(height, width, 3)``.
    """
    if integrator is None:
        integrator = Path2()
    # Resolve before chdir: afterwards a relative ``scene_file`` would be
    # interpreted against the scene's own directory and miss the file.
    scene_path = os.path.abspath(scene_file)
    previous_cwd = os.getcwd()
    os.chdir(os.path.dirname(scene_path))
    try:
        scene = psdr_cpu.Scene(scene_path)
        return render_scene(scene, options, integrator)
    finally:
        # Don't leak a process-wide cwd change to the caller.
        os.chdir(previous_cwd)


def render_multi_views(scene, cameras, options, integrator=None) -> List[np.ndarray]:
    """Render ``scene`` once per camera and return the images in camera order.

    Note: ``scene.camera`` is reassigned for every view and is left set to the
    last camera in ``cameras`` when this function returns.

    Args:
        scene: Scene to render; its ``camera`` attribute is overwritten.
        cameras: Iterable of cameras; each supplies ``width``/``height``.
        options: Render options passed to the integrator.
        integrator: Object exposing ``renderC(scene, options)``; defaults to a
            fresh ``Path2()`` when omitted.

    Returns:
        One image per camera, each reshaped to ``(height, width, 3)``.
    """
    if integrator is None:
        # Per-call default instead of a single instance shared since import.
        integrator = Path2()
    images = []
    for camera in cameras:
        scene.camera = camera  # switch the active viewpoint
        image = integrator.renderC(scene, options).reshape(
            (camera.height, camera.width, 3))
        images.append(image)
    return images


def render_save_multi_views(scene, cameras, options, integrator, out_dir) -> None:
    """Render ``scene`` from every camera and write each image to disk.

    The image for ``cameras[i]`` is written to ``<out_dir>/sensor_<i>.exr``.
    Nothing is returned; ``scene.camera`` is left set to the last camera.

    Args:
        scene: Scene to render; its ``camera`` attribute is overwritten.
        cameras: Iterable of cameras; each supplies ``width``/``height``.
        options: Render options passed to ``integrator.renderC``.
        integrator: Object exposing ``renderC(scene, options)``.
        out_dir: Output directory; created if it does not exist. An empty
            string means the current directory.
    """
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    for i, camera in enumerate(cameras):
        scene.camera = camera  # switch the active viewpoint
        image = integrator.renderC(scene, options).reshape(
            (camera.height, camera.width, 3))
        imwrite(image, os.path.join(out_dir, "sensor_{}.exr".format(i)))

def render_save_multi_pose(scene, shape_id, xforms, options, integrator, out_dir) -> None:
    """Render ``scene`` with one shape moved by each rigid transform; save images.

    For every 4x4 matrix in ``xforms``:
    1. Transform the shape's original (pre-loop) vertices as ``R @ v + t``,
       where ``R = xform[:3, :3]`` and ``t = xform[:3, 3]``.
    2. Reconfigure the scene.
    3. Write the render to ``<out_dir>/sensor_<i>.exr``.

    Note: the shape keeps the vertices of the last pose on return.

    Args:
        scene: Scene to render; ``scene.shapes[shape_id]`` is modified.
        shape_id: Index of the shape to move.
        xforms: Iterable of 4x4 transforms (translation in the last column).
        options: Render options passed to ``integrator.renderC``.
        integrator: Object exposing ``renderC(scene, options)``.
        out_dir: Output directory; created if missing ("" = current dir).

    Raises:
        ValueError: If a transform is not 4x4.
    """
    shape = scene.shapes[shape_id]
    rest_vertices = np.array(shape.vertices)
    camera = scene.camera
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    for i, xform in enumerate(xforms):
        xform = np.asarray(xform)
        # Explicit check rather than ``assert``, which is stripped under -O.
        if xform.shape != (4, 4):
            raise ValueError(
                "xforms[{}] must be a 4x4 matrix, got shape {}".format(i, xform.shape))
        rotation, translation = xform[:3, :3], xform[:3, 3]
        # Vertices are rows, so applying R to each one is V @ R.T (the same
        # convention render_save_multi_rotation uses). Multiplying by R
        # directly would apply the inverse rotation.
        shape.setVertices(rest_vertices @ rotation.T + translation)
        scene.configure()
        image = integrator.renderC(scene, options).reshape(
            (camera.height, camera.width, 3))
        imwrite(image, os.path.join(out_dir, "sensor_{}.exr".format(i)))

def render_save_multi_rotation(scene, obj_pos, options, integrator, out_dir,
                               shape_id=0) -> None:
    """Render ``scene`` with one shape rotated to each pose and save the images.

    Each entry ``pos`` in ``obj_pos`` is read as an axis ``pos[0:3]`` and an
    angle ``pos[3]``. Their product is passed to
    ``Rotation.from_rotvec``, which takes radians. The axis is presumably
    expected to have unit length (TODO confirm with callers); otherwise the
    angle is scaled by its norm.

    For each pose:
    1. Rotate the shape's original vertices.
    2. Reconfigure the scene.
    3. Write the render to ``<out_dir>/sensor_<i>.exr``.

    The shape keeps the last pose on return.

    Args:
        scene: Scene to render; ``scene.shapes[shape_id]`` is modified.
        obj_pos: Iterable of length-4 sequences ``(ax, ay, az, angle)``.
        options: Render options passed to ``integrator.renderC``.
        integrator: Object exposing ``renderC(scene, options)``.
        out_dir: Output directory; created if missing ("" = current dir).
        shape_id: Index of the shape to rotate (defaults to the first shape,
            as before).
    """
    rest_vertices = np.array(scene.shapes[shape_id].vertices)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    for i, pos in enumerate(obj_pos):
        print(i, pos)  # progress output
        # asarray so a plain list works too: ``int * list`` would repeat the
        # list instead of scaling the axis.
        pose = np.asarray(pos, dtype=float)
        rotation = R.from_rotvec(pose[3] * pose[0:3])
        # Vertices are rows, so rotating each one is V @ M.T.
        scene.shapes[shape_id].setVertices(rest_vertices @ rotation.as_matrix().T)
        scene.configure()
        image = integrator.renderC(scene, options).reshape(
            (scene.camera.height, scene.camera.width, 3))
        imwrite(image, os.path.join(out_dir, "sensor_{}.exr".format(i)))



if __name__ == '__main__':
    # Demo: render a scene from sampled viewpoints and display one view.
    # Usage: python <this file> [scene.xml]
    import sys

    import matplotlib.pyplot as plt

    options = psdr_cpu.RenderOptions(7,      # random seed
                                     128,    # spp
                                     1,      # max bounces
                                     0,      # sppe
                                     0,      # sppse0
                                     False)  # NOTE(review): meaning of this flag is not visible here
    psdr_cpu.set_verbose(True)
    # The scene can be passed on the command line; the fallback is the
    # original author's local test scene.
    default_scene = '/home/javis/Documents/MyProjects/psdr_enzyme/tests/inverse_rendering/kitty/tar.xml'
    scene_file = sys.argv[1] if len(sys.argv) > 1 else default_scene
    scene_path = os.path.abspath(scene_file)
    # Load from inside the scene's folder (presumably so relative asset paths
    # in the file resolve).
    os.chdir(os.path.dirname(scene_path))
    scene = psdr_cpu.Scene(os.path.basename(scene_path))
    positions = sample_on_sphere(10, 100)
    cameras = gen_cameras(positions=positions, target=[0, 0, 0],
                          fov=10., resolution=[256, 256])
    images = render_multi_views(scene, cameras, options)
    # Gamma-encode (1/2.2) and quantize to 8 bits for display.
    image = np.power(images[2], 1/2.2)
    image = np.uint8(np.clip(image * 255., 0., 255.))
    plt.imshow(image)
    plt.show()
