import math
import torch
from scipy.spatial.transform import Rotation as R
import numpy as np

def translate(vertices, vec):
    """Return ``vertices`` shifted by ``vec``.

    The result is simply ``vertices + vec``; any broadcasting follows the
    ``+`` semantics of the operand types (the module's demo passes a tensor
    of 3-component rows and a 3-component offset tensor).
    """
    shifted = vertices + vec
    return shifted

def rotate(vertices, ang_vel, t):
    """Rotate row-vector ``vertices`` by the rotation ``Rotation3x3(ang_vel, t)``.

    With vertices stored one per row, applying a rotation matrix ``M`` to
    every vertex is ``vertices @ M.T``; that is what this function computes.

    Args:
        vertices: Vertex coordinates, one vertex per row (the module's demo
            passes a ``(3, 3)`` float tensor, i.e. three 3-D points).
        ang_vel: Angular velocity vector; forwarded to ``Rotation3x3``.
        t: Duration; forwarded to ``Rotation3x3``.

    Returns:
        The result of ``vertices @ M.T``. Its container type is whatever the
        ``@`` operator yields for the operand types involved.
    """
    rot = Rotation3x3(ang_vel, t)
    # ``ndarray.transpose(0, 1)`` is the identity permutation on a 2-D NumPy
    # array (unlike ``Tensor.transpose(0, 1)``), so the previous code silently
    # skipped the transpose and applied the inverse rotation. ``swapaxes``
    # exists on both NumPy arrays and torch tensors and always swaps the two
    # matrix axes.
    return vertices @ rot.swapaxes(-1, -2)

def Rotation3x3(ang_vel, t):
    """Return the rotation matrix for the rotation vector ``ang_vel * t``.

    The product ``ang_vel * t`` is interpreted by SciPy as a rotation vector:
    its direction is the rotation axis and its norm is the angle in radians.

    Args:
        ang_vel: Angular velocity vector with three components.
        t: Duration that scales ``ang_vel``.

    Returns:
        The matrix produced by ``scipy.spatial.transform.Rotation.as_matrix``
        (a NumPy array under SciPy's default configuration).
    """
    # A second ``return`` used to follow this one; it could never execute and
    # has been removed.
    return R.from_rotvec(ang_vel * t).as_matrix()

def gen_rotation_matrix3x3(ang_vel, t):
    """Build a rotation matrix with Rodrigues' formula using torch ops.

    The rotation axis is ``ang_vel`` normalised to unit length and the
    rotation angle is ``|ang_vel| * t``.

    Args:
        ang_vel: Indexable with three components (a 3-element tensor in the
            module's demo). Its norm is taken with ``math.sqrt``, so the norm
            itself is a plain Python float.
        t: Duration. May be a Python number, a 0-d tensor, or a tensor with
            at least one dimension, in which case only its first entry along
            dimension 0 is used (the historical behaviour).

    Returns:
        A ``(3, 3)`` tensor when ``t`` is a scalar or 1-D tensor. When
        ``|ang_vel|`` is not greater than ``1e-4`` the rotation is treated as
        negligible and ``torch.eye(3)`` is returned.
    """
    wx, wy, wz = ang_vel[0], ang_vel[1], ang_vel[2]
    vel = math.sqrt(wx*wx + wy*wy + wz*wz)
    # ``not (vel > ...)`` rather than ``vel <= ...`` so a NaN norm still
    # falls back to the identity, exactly as before.
    if not vel > 1e-4:
        return torch.eye(3)

    # Accept plain numbers and 0-d tensors as well as the 1-element tensors
    # the original indexing required; tensors with a leading dimension keep
    # contributing only their first entry.
    theta = torch.as_tensor(vel * t)
    if theta.dim() > 0:
        theta = theta[0]

    ux = wx / vel
    uy = wy / vel
    uz = wz / vel
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    one_minus_cos = 1 - cos_theta

    return torch.stack([
        torch.stack([
            cos_theta + ux*ux*one_minus_cos,
            ux*uy*one_minus_cos - uz*sin_theta,
            ux*uz*one_minus_cos + uy*sin_theta,
        ]),
        torch.stack([
            uy*ux*one_minus_cos + uz*sin_theta,
            cos_theta + uy*uy*one_minus_cos,
            uy*uz*one_minus_cos - ux*sin_theta,
        ]),
        torch.stack([
            uz*ux*one_minus_cos - uy*sin_theta,
            uz*uy*one_minus_cos + ux*sin_theta,
            cos_theta + uz*uz*one_minus_cos,
        ]),
    ])


def sample_spherical(npoints, ndim=3):
    """Draw ``npoints`` random unit vectors in ``ndim`` dimensions.

    Standard-normal samples are drawn with ``np.random.randn`` as an
    ``(ndim, npoints)`` array, each column is divided by its Euclidean norm,
    and the transposed ``(npoints, ndim)`` array is returned.
    """
    gaussian = np.random.randn(ndim, npoints)
    lengths = np.linalg.norm(gaussian, axis=0)
    unit_columns = gaussian / lengths
    return unit_columns.T

def sample_rotation(n):
    """Draw ``n`` random rotations and return them as rotation matrices.

    Args:
        n: Number of rotations to sample; passed to
            ``scipy.spatial.transform.Rotation.random``.

    Returns:
        The stacked matrices from ``Rotation.as_matrix()``, shape
        ``(n, 3, 3)`` for integer ``n``.
    """
    # The matrices were previously computed and discarded, so callers always
    # received ``None``.
    return R.random(n).as_matrix()

if __name__ == "__main__":
    # Ad-hoc demo: print the output of each helper for a few fixed inputs.
    print(sample_spherical(10))

    vertices = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
    print(vertices)

    offset = torch.tensor([1., 2., 3.])
    print(translate(vertices, offset))

    y_axis = torch.tensor([0., 1., 0.])
    print(rotate(vertices, y_axis, torch.tensor([math.pi])))
    print(Rotation3x3(y_axis, torch.tensor([1])))

    # NOTE(review): a disabled check here used to expect a (4, 4) shape;
    # SciPy documents as_matrix() as returning (3, 3) for a single rotation.
    xform = R.from_rotvec([0, 1, 0]).as_matrix()
    print(xform)

    print(sample_spherical(10).T)