Skip to content

Commit

Permalink
More general pack/unpack for serialization helper
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Mar 11, 2021
1 parent eab5330 commit ba54b59
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions jaxlie/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -95,10 +96,18 @@ def _unflatten_group(treedef: Any, children: Sequence[Any]) -> T:

# Make object flax-serializable
def _ty_to_state_dict(x: "MatrixLieGroup") -> Dict[str, types.Array]:
return {"params": x.parameters}
return {
key: flax.serialization.to_state_dict(value)
for key, value in vars(x).items()
}

def _ty_from_state_dict(x: "MatrixLieGroup", state: Dict) -> "MatrixLieGroup":
return type(x)(state["params"])
updates: Dict[str, Any] = {}
for key, value in vars(x).items():
updates[key] = flax.serialization.from_state_dict(
getattr(x, key), value
)
return dataclasses.replace(x, **updates)

flax.serialization.register_serialization_state(
ty=cls,
Expand Down

0 comments on commit ba54b59

Please sign in to comment.