Skip to content

Commit

Permalink
Unofficial support for batch axes for various ops
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Feb 24, 2021
1 parent 065c9f3 commit 0bafd98
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def from_rotation_and_translation(

@property
def rotation(self) -> SO2:
return SO2(unit_complex=self.xy_unit_complex[2:])
return SO2(unit_complex=self.xy_unit_complex[..., 2:])

@property
def translation(self) -> types.Vector:
return self.xy_unit_complex[:2]
return self.xy_unit_complex[..., :2]

# Factory

Expand Down
4 changes: 2 additions & 2 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def from_rotation_and_translation(

@property
def rotation(self) -> SO3:
return SO3(wxyz=self.xyz_wxyz[3:])
return SO3(wxyz=self.xyz_wxyz[..., 3:])

@property
def translation(self) -> types.Vector:
return self.xyz_wxyz[:3]
return self.xyz_wxyz[..., :3]

# Factory

Expand Down
6 changes: 4 additions & 2 deletions jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def from_radians(theta: types.Scalar) -> "SO2":
return SO2(unit_complex=jnp.array([cos, sin]))

def as_radians(self) -> jnp.ndarray:
(radians,) = self.log()
radians = self.log()[..., 0]
return radians

# Factory
Expand Down Expand Up @@ -89,7 +89,9 @@ def exp(tangent: types.TangentVector) -> "SO2":

@overrides
def log(self: "SO2") -> types.TangentVector:
return jnp.arctan2(self.unit_complex[1, None], self.unit_complex[0, None])
return jnp.arctan2(
self.unit_complex[..., 1, None], self.unit_complex[..., 0, None]
)

@overrides
def adjoint(self: "SO2") -> types.Matrix:
Expand Down
4 changes: 2 additions & 2 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ def log(self: "SO3") -> types.TangentVector:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247

w = self.wxyz[0]
norm_sq = self.wxyz[1:] @ self.wxyz[1:]
w = self.wxyz[..., 0]
norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:]
norm = jnp.sqrt(norm_sq)
use_taylor = norm < get_epsilon(norm_sq.dtype)
atan_factor = jnp.where(
Expand Down

0 comments on commit 0bafd98

Please sign in to comment.