Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699400177
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 23, 2024
1 parent ab94736 commit ae1c3b4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ def _save_fn():
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
pytree_metadata_options=self._pytree_metadata_options,
)
path.write_text(json.dumps(metadata_content.to_json()))
jax.monitoring.record_event_duration_secs(
Expand Down Expand Up @@ -767,7 +768,8 @@ def _read_metadata_file(
f' {directory}.'
)
return tree_metadata.InternalTreeMetadata.from_json(
json.loads(path.read_text())
json.loads(path.read_text()),
pytree_metadata_options=self._pytree_metadata_options,
)

def metadata(self, directory: epath.Path) -> Optional[PyTree]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(
save_concurrent_gb: int = 96,
restore_concurrent_gb: int = 96,
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
pytree_metadata_options: pytree_metadata_options_lib.PyTreeMetadataOptions = (
pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS
),
):
"""Creates StandardCheckpointHandler.
Expand All @@ -79,12 +83,15 @@ def __init__(
Can help to reduce the possibility of OOM's when large checkpoints are
restored.
multiprocessing_options: See orbax.checkpoint.options.
pytree_metadata_options: Options to control types like tuple and
namedtuple in pytree metadata.
"""
self._supported_types = checkpoint_utils.STANDARD_ARRAY_TYPES
self._impl = pytree_checkpoint_handler.PyTreeCheckpointHandler(
save_concurrent_gb=save_concurrent_gb,
restore_concurrent_gb=restore_concurrent_gb,
multiprocessing_options=multiprocessing_options,
pytree_metadata_options=pytree_metadata_options,
)

def _validate_save_state(
Expand Down
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class EmptyNamedTuple(NamedTuple):
class NamedTupleWithNestedAttributes(NamedTuple):
nested_mu_nu: MuNu | None = None
nested_dict: Dict[str, jax.Array] | None = None
nested_tuple: Tuple[jax.Array, np.ndarray] | None = None
nested_tuple: Tuple[jax.Array, jax.Array] | None = None
nested_empty_named_tuple: EmptyNamedTuple | None = None
my_empty_chex: MyEmptyChex | None = None

Expand Down Expand Up @@ -823,7 +823,7 @@ def __repr__(self):
'named_tuple_with_nested_attrs': NamedTupleWithNestedAttributes(
nested_mu_nu=MuNu(mu=jnp.arange(8), nu=np.arange(8)),
nested_dict={'a': jnp.arange(8), 'b': np.arange(8)},
nested_tuple=(jnp.arange(8), np.arange(8)),
nested_tuple=(jnp.arange(8), jnp.arange(8)),
nested_empty_named_tuple=EmptyNamedTuple(),
my_empty_chex=MyEmptyChex(),
)
Expand Down Expand Up @@ -856,7 +856,7 @@ def __repr__(self):
skip_deserialize=False,
),
tree_metadata.ValueMetadataEntry(
value_type='np.ndarray',
value_type='jax.Array',
skip_deserialize=False,
),
],
Expand Down Expand Up @@ -917,7 +917,7 @@ def __repr__(self):
skip_deserialize=False,
),
tree_metadata.ValueMetadataEntry(
value_type='np.ndarray',
value_type='jax.Array',
skip_deserialize=False,
),
),
Expand Down

0 comments on commit ae1c3b4

Please sign in to comment.