Skip to content

Commit

Permalink
Add flax serialization support, bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 13, 2021
1 parent 6426989 commit ef9a82c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 13 deletions.
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,33 @@
![lint](https://github.com/brentyi/jaxlie/workflows/lint/badge.svg)
[![codecov](https://codecov.io/gh/brentyi/jaxlie/branch/master/graph/badge.svg)](https://codecov.io/gh/brentyi/jaxlie)

**[ [API reference](https://brentyi.github.io/jaxlie) ]**
**[ [PyPI](https://pypi.org/project/jaxlie/) ]**
**[ [API reference](https://brentyi.github.io/jaxlie) ]** **[
[PyPI](https://pypi.org/project/jaxlie/) ]**

`jaxlie` is a Lie theory library for rigid body transformations and optimization
in JAX.

Current functionality:

- SO(2), SE(2), SO(3), and SE(3) Lie groups implemented as high-level
dataclasses.
- High-level interfaces for SO(2), SE(2), SO(3), and SE(3) Lie groups.
- **`exp()`**, **`log()`**, **`adjoint()`**, **`multiply()`**, **`inverse()`**,
and **`identity()`** implementations for each Lie group.
- Pytree registration for all dataclasses.
and **`identity()`** implementations for each group.
- Helpers + analytical Jacobians for on-manifold optimization
(**`jaxlie.manifold`**).
- Dataclass-style implementations, with support for (un)flattening as pytree
nodes and serialization using [flax](https://github.com/google/flax).

---

##### Install (Python >=3.6)

```bash
pip install jaxlie
```

---

##### Example usage
##### Example usage for SE(3)

```python
import numpy as onp
Expand All @@ -45,7 +46,6 @@ from jaxlie import SE3
# to `SE3.from_matrix(expm(wedge(twist)))`
twist = onp.array([1.0, 0.0, 0.2, 0.0, 0.5, 0.0])
T_w_b = SE3.exp(twist)
p_b = onp.random.randn(3)

# We can print the (quaternion) rotation term; this is a `SO3` object:
print(T_w_b.rotation)
Expand Down Expand Up @@ -75,6 +75,7 @@ T_w_b = SE3(xyz_wxyz=T_w_b.xyz_wxyz)
#############################

# Transform points with the `@` operator:
p_b = onp.random.randn(3)
p_w = T_w_b @ p_b
print(p_w)

Expand Down
20 changes: 19 additions & 1 deletion jaxlie/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Callable, Dict, Tuple, Type, TypeVar

import flax
import jax
from jax import numpy as jnp

Expand Down Expand Up @@ -29,6 +30,10 @@ def register_lie_group(
- Makes the group hashable
- Marks all functions for JIT compilation
- Adds flattening/unflattening ops for use as a PyTree node
- Adds serialization ops for `flax.serialization`
Note that a significant amount of functionality here could be replaced by
`flax.struct`, but `flax.struct` doesn't work very well with jedi or mypy.
Example:
```
Expand Down Expand Up @@ -78,6 +83,19 @@ def _unflatten_group(

jax.tree_util.register_pytree_node(cls, _flatten_group, _unflatten_group)

# Make object flax-serializable
def _ty_to_state_dict(x: "MatrixLieGroup") -> Dict[str, jnp.ndarray]:
return {"params": x.parameters}

def _ty_from_state_dict(x: "MatrixLieGroup", state: Dict) -> "MatrixLieGroup":
return type(x)(state["params"])

flax.serialization.register_serialization_state(
ty=cls,
ty_to_state_dict=_ty_to_state_dict,
ty_from_state_dict=_ty_from_state_dict,
)

return cls

return _wrap
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@

setup(
name="jaxlie",
version="0.0.2",
version="0.0.3",
description="Matrix Lie groups in Jax",
long_description=long_description,
long_description_content_type="text/markdown",
url="http://github.com/brentyi/jaxlie",
author="brentyi",
author_email="[email protected]",
license="MIT",
packages=find_packages(exclude=["examples", "tests"]),
packages=find_packages(),
package_data={"liejax": ["py.typed"]},
python_requires=">=3.6",
install_requires=[
"flax",
"jax",
"jaxlib",
"numpy",
Expand Down
32 changes: 32 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Tests for transform serialization, for things like saving calibrated transforms to
disk.
"""

from typing import Type

import flax
from utils import assert_transforms_close, general_group_test, sample_transform

import jaxlie


@general_group_test
def test_serialization_state_dict_bijective(
Group: Type[jaxlie.MatrixLieGroup], _random_module
):
"""Check bijectivity of state dict representation conversations."""
T = sample_transform(Group)
T_recovered = flax.serialization.from_state_dict(
T, flax.serialization.to_state_dict(T)
)
assert_transforms_close(T, T_recovered)


@general_group_test
def test_serialization_bytes_bijective(
Group: Type[jaxlie.MatrixLieGroup], _random_module
):
"""Check bijectivity of byte representation conversations."""
T = sample_transform(Group)
T_recovered = flax.serialization.from_bytes(T, flax.serialization.to_bytes(T))
assert_transforms_close(T, T_recovered)
5 changes: 3 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def assert_transforms_close(a: jaxlie.MatrixLieGroup, b: jaxlie.MatrixLieGroup):
assert_arrays_close(a.as_matrix(), b.as_matrix())

# Flip signs for quaternions
p1 = a.parameters
p2 = b.parameters
# We use `jnp.asarray` here in case inputs are onp arrays and don't support `.at()`
p1 = jnp.asarray(a.parameters)
p2 = jnp.asarray(b.parameters)
if isinstance(a, jaxlie.SO3):
p1 = p1 * jnp.sign(jnp.sum(p1))
p2 = p2 * jnp.sign(jnp.sum(p2))
Expand Down

0 comments on commit ef9a82c

Please sign in to comment.