Skip to content

Commit

Permalink
Move utils, tweak docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Mar 15, 2021
1 parent 6d17985 commit bce582e
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 16 deletions.
6 changes: 2 additions & 4 deletions jaxlie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
from . import manifold, types
from . import manifold, types, utils
from ._base import MatrixLieGroup, SEBase, SOBase
from ._se2 import SE2
from ._se3 import SE3
from ._so2 import SO2
from ._so3 import SO3
from ._utils import register_lie_group

__all__ = [
"manifold",
"types",
"utils",
"MatrixLieGroup",
"SOBase",
"SEBase",
"SE2",
"SO2",
"SE3",
"SO3",
"types",
"register_lie_group",
]
7 changes: 4 additions & 3 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,21 @@ def sample_uniform(cls: Type[GroupType], key: jnp.ndarray) -> GroupType:


class SOBase(MatrixLieGroup):
pass
"""Base class for special orthogonal groups."""


class SEBase(MatrixLieGroup):
"""Base class for special Euclidean groups."""

# Standard interface
# SE-specific interface

@staticmethod
@abc.abstractmethod
def from_rotation_and_translation(
rotation: SOBase,
translation: types.Vector,
) -> SEGroupType:
"""Construct an rigid transform from a rotation and a translation."""
"""Construct a rigid transform from a rotation and a translation."""

@abc.abstractmethod
def rotation(self) -> SOBase:
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import _base, types
from ._so2 import SO2
from ._utils import get_epsilon, register_lie_group
from .utils import get_epsilon, register_lie_group


@register_lie_group(
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import _base, types
from ._so3 import SO3
from ._utils import get_epsilon, register_lie_group
from .utils import get_epsilon, register_lie_group


def _skew(omega: types.Vector) -> types.Matrix:
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from overrides import final, overrides

from . import _base, types
from ._utils import register_lie_group
from .utils import register_lie_group


@register_lie_group(
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from overrides import final, overrides

from . import _base, types
from ._utils import get_epsilon, register_lie_group
from .utils import get_epsilon, register_lie_group


@register_lie_group(
Expand Down
6 changes: 6 additions & 0 deletions jaxlie/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._utils import get_epsilon, register_lie_group

__all__ = [
"get_epsilon",
"register_lie_group",
]
11 changes: 6 additions & 5 deletions jaxlie/_utils.py → jaxlie/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,23 @@
import jax
from jax import numpy as jnp

from . import types
from .. import types

if TYPE_CHECKING:
from ._base import MatrixLieGroup
from .._base import MatrixLieGroup


T = TypeVar("T", bound="MatrixLieGroup")


def get_epsilon(dtype: jnp.dtype) -> float:
"""Helper for grabbing type-specific precision constants."""
return {
jnp.dtype("float32"): 1e-5,
jnp.dtype("float64"): 1e-10,
}[dtype]


T = TypeVar("T", bound="MatrixLieGroup")


def register_lie_group(
*,
matrix_dim: int,
Expand Down

0 comments on commit bce582e

Please sign in to comment.