From 7cd3e37fa73c861e16a565372bdf12cd5da0ead0 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 2 Jun 2021 01:05:53 -0700 Subject: [PATCH] Fixes for `jax_dataclasses` API change --- jaxlie/_se2.py | 2 +- jaxlie/_se3.py | 2 +- jaxlie/_so2.py | 2 +- jaxlie/_so3.py | 2 +- setup.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlie/_se2.py b/jaxlie/_se2.py index 4b73617..c2b66cb 100644 --- a/jaxlie/_se2.py +++ b/jaxlie/_se2.py @@ -15,7 +15,7 @@ tangent_dim=3, space_dim=2, ) -@jax_dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class SE2(_base.SEBase[SO2]): """Special Euclidean group for proper rigid transforms in 2D.""" diff --git a/jaxlie/_se3.py b/jaxlie/_se3.py index 0005ab9..c1acf71 100644 --- a/jaxlie/_se3.py +++ b/jaxlie/_se3.py @@ -28,7 +28,7 @@ def _skew(omega: hints.Vector) -> hints.MatrixJax: tangent_dim=6, space_dim=3, ) -@jax_dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class SE3(_base.SEBase[SO3]): """Special Euclidean group for proper rigid transforms in 3D.""" diff --git a/jaxlie/_so2.py b/jaxlie/_so2.py index 4403afb..f3c3d4e 100644 --- a/jaxlie/_so2.py +++ b/jaxlie/_so2.py @@ -14,7 +14,7 @@ tangent_dim=1, space_dim=2, ) -@jax_dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class SO2(_base.SOBase): """Special orthogonal group for 2D rotations.""" diff --git a/jaxlie/_so3.py b/jaxlie/_so3.py index 6c79ea6..26d94b3 100644 --- a/jaxlie/_so3.py +++ b/jaxlie/_so3.py @@ -14,7 +14,7 @@ tangent_dim=3, space_dim=3, ) -@jax_dataclasses.dataclass +@jax_dataclasses.pytree_dataclass class SO3(_base.SOBase): """Special orthogonal group for 3D rotations.""" diff --git a/setup.py b/setup.py index a75861c..d008941 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ "jax", "jaxlib", "numpy", - "jax_dataclasses", + "jax_dataclasses>=1.0.0", # `overrides` should not be updated until the following issues are resolved: # > https://github.com/mkorpela/overrides/issues/65 # > https://github.com/mkorpela/overrides/issues/63