from psdr_cpu import Camera, look_at
import psdr_cpu
import numpy as np
from scipy.spatial.transform import Rotation as R


def sample_on_sphere(batch, radius, cos_phi_min=-1.0, cos_phi_max=1.0):
    """
    Return ``batch`` points on a sphere of radius ``radius`` (y is the polar axis).

    ``cos(phi)`` is spaced linearly from ``cos_phi_max - 0.01`` down to
    ``cos_phi_min + 0.01``, i.e. the samples stay 0.01 inside the requested
    cosine range. The azimuth ``theta`` sweeps 0 to 10*pi (five full turns)
    with a per-sample uniform jitter in ``[0, 1/batch)`` added on top. The
    result therefore depends on NumPy's global random state.

    Args:
        batch: number of points to generate.
        radius: sphere radius.
        cos_phi_min: lower bound of the polar-angle cosine range.
        cos_phi_max: upper bound of the polar-angle cosine range.

    Returns:
        ``(batch, 3)`` array of ``(x, y, z)`` points.
    """
    margin = 0.01
    cos_phi = np.linspace(cos_phi_max - margin, cos_phi_min + margin, batch)

    jitter = np.random.uniform(0, 1, batch) * (1 / float(batch))
    theta = np.linspace(0, np.pi * 10, batch) + jitter

    sin_phi = np.sqrt(1 - cos_phi * cos_phi)
    x = sin_phi * np.cos(theta) * radius
    y = cos_phi * radius
    z = sin_phi * np.sin(theta) * radius
    return np.stack((x, y, z), axis=-1)


def sample_on_hemisphere(batch, radius):
    """
    Sample ``batch`` points on the upper (y >= 0) hemisphere of radius ``radius``.

    ``cos(phi)`` is spaced linearly from 0.99 down to 0.0. The points
    therefore run from just below the +y pole down to the equator: the pole
    itself is excluded and the equator is included. The azimuth ``theta``
    sweeps 0 to 10*pi (five full turns) with a per-sample uniform jitter in
    ``[0, 1/batch)`` added on top. The result therefore depends on NumPy's
    global random state.

    Args:
        batch: number of points to generate.
        radius: hemisphere radius (y >= 0 holds for non-negative radius).

    Returns:
        ``(batch, 3)`` array of ``(x, y, z)`` points with y as the up axis.
    """
    cosPhi = np.linspace(1.0 - 0.01, 0.0, batch)
    theta = np.linspace(0, np.pi * 10, batch) + \
        np.random.uniform(0, 1, batch) * (1/float(batch))
    sinPhi = np.sqrt(1 - cosPhi * cosPhi)
    sinTheta = np.sin(theta)
    cosTheta = np.cos(theta)

    return np.array([sinPhi * cosTheta * radius, cosPhi * radius, sinPhi * sinTheta * radius]).T


def gen_camera(origin, target, up, fov, resolution, type=None):
    """
    Build a psdr_cpu ``Camera`` placed at ``origin`` and looking at ``target``.

    Args:
        origin: camera position, forwarded to ``look_at``.
        target: point the camera looks at, forwarded to ``look_at``.
        up: up direction, forwarded to ``look_at``.
        fov: field of view, forwarded unchanged to ``Camera``.
        resolution: ``(width, height)``; ``resolution[0]`` is the width and
            ``resolution[1]`` the height.
        type: camera type, forwarded to ``Camera``. If ``None`` (the default),
            the keyword is not passed, so ``Camera`` uses its own default.
            NOTE(review): assumes ``psdr_cpu.Camera`` declares a default for
            ``type``; confirm against the bindings.

    Returns:
        The constructed ``Camera``.
    """
    # `type` shadows the builtin but is kept so keyword callers keep working.
    camera_kwargs = dict(width=resolution[0], height=resolution[1],
                         fov=fov, to_world=look_at(origin, target, up))
    if type is not None:
        camera_kwargs['type'] = type
    return Camera(**camera_kwargs)


def gen_cameras(positions, target, up, fov, resolution, type):
    """
    Build one camera per entry of ``positions``, all aimed at ``target``.

    Every camera shares ``target``, ``up``, ``fov``, ``resolution`` and
    ``type``; only the origin differs between cameras.

    Returns:
        List of cameras in the same order as ``positions``.
    """
    cameras = []
    for position in positions:
        cameras.append(gen_camera(position, target, up, fov, resolution, type))
    return cameras


def gen_random_rotations(batch, vec=None):
    """
    Generate ``batch`` rotation matrices.

    Args:
        batch: number of rotations to produce.
        vec: optional rotation axis (length-3 array-like).

            - If ``None``, ``batch`` uniformly random rotations are drawn
              with ``Rotation.random``.
            - Otherwise, rotation ``i`` has rotation vector ``theta_i * vec``,
              with ``theta`` evenly spaced over ``[0, 2*pi]``. Both endpoints
              are included, so the first and last rotations coincide.

            The rotation angle is ``theta_i * |vec|``, so pass a unit vector
            to get angles of exactly ``theta_i``.

    Returns:
        ``(batch, 3, 3)`` array of rotation matrices.
    """
    if vec is None:
        return R.random(batch).as_matrix()
    theta = np.linspace(0, 2 * np.pi, batch)
    # Scale the whole axis by each angle: (batch, 1) * (3,) -> (batch, 3).
    # A plain `theta * vec` multiplies element-wise. It fails to broadcast
    # unless batch == 3, and then yields a single wrong rotation vector.
    rotvecs = theta[:, np.newaxis] * np.asarray(vec, dtype=float).reshape(3)
    return R.from_rotvec(rotvecs).as_matrix()


if __name__ == '__main__':
    # Smoke test: one camera at the origin looking toward +z.
    # NOTE(review): this call omits `type`; it only runs if gen_camera gives
    # `type` a default value. Confirm.
    cam = gen_camera(origin=[0., 0., 0.], target=[0., 0., 1.],
                     up=[0., 1., 0.], fov=45., resolution=[512, 512])
    print(cam.cam_to_world)
    print(cam.cam_to_ndc)

    # Two uniformly random rotation matrices.
    print(gen_random_rotations(2))

    # Rotate three axis-aligned points by 90 degrees about z.
    # Points are stored as rows, so the rotation is applied as points @ rot.T.
    points = np.array([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])
    rot_z = R.from_euler('z', 90, degrees=True).as_matrix()
    print(rot_z)
    print(points @ rot_z.T)

    # Rotation of pi radians about y, built from a rotation vector.
    rot_y = R.from_rotvec([0, np.pi, 0]).as_matrix()
    print(rot_y)
