diff --git a/README.md b/README.md index bb1ad7f..4d95bad 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ rotvec = torch.randn(batch_shape + (3,)) q = roma.rotvec_to_unitquat(rotvec) R = roma.unitquat_to_rotmat(q) Rbis = roma.rotvec_to_rotmat(rotvec) +euler_angles = roma.unitquat_to_euler('xyz', q, as_tensor=True, degrees=True) # Regression of a rotation from an arbitrary input: # Special Procrustes orthonormalization of a 3x3 matrix diff --git a/roma/euler.py b/roma/euler.py index b6e05b3..5cba988 100644 --- a/roma/euler.py +++ b/roma/euler.py @@ -63,7 +63,10 @@ def euler_to_unitquat(convention: str, angles, degrees=False, normalize=True, dt raise ValueError("Invalid convention (expected format: 'xyz', 'zxz', 'XYZ', etc.).") q = roma.rotvec_to_unitquat(rotvec) unitquats.append(q) - return roma.quat_composition(unitquats, normalize=normalize) + if len(unitquats) == 1: + return unitquats[0] + else: + return roma.quat_composition(unitquats, normalize=normalize) def euler_to_rotvec(convention: str, angles, degrees=False, dtype=None, device=None): """