From d48b845c173c3fe6eb2b38ef1698a24d9739822b Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Fri, 22 Nov 2024 22:50:47 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 699400177 --- .../_src/handlers/base_pytree_checkpoint_handler.py | 4 +++- .../_src/handlers/standard_checkpoint_handler.py | 7 +++++++ .../orbax/checkpoint/_src/testing/test_tree_utils.py | 8 ++++---- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index b61c5d93d..d1fb50b90 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -736,6 +736,7 @@ def _save_fn(): param_infos, save_args=save_args, use_zarr3=use_zarr3, + pytree_metadata_options=self._pytree_metadata_options, ) path.write_text(json.dumps(metadata_content.to_json())) jax.monitoring.record_event_duration_secs( @@ -767,7 +768,8 @@ def _read_metadata_file( f' {directory}.' ) return tree_metadata.InternalTreeMetadata.from_json( - json.loads(path.read_text()) + json.loads(path.read_text()), + pytree_metadata_options=self._pytree_metadata_options, ) def metadata(self, directory: epath.Path) -> Optional[PyTree]: diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index 25adba965..0249d858d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -30,6 +30,7 @@ from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.tree import utils as tree_utils @@ -68,6 +69,9 @@ def __init__( save_concurrent_gb: int = 96, restore_concurrent_gb: int = 96, multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(), + pytree_metadata_options: pytree_metadata_options_lib.PyTreeMetadataOptions = ( + pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS + ), ): """Creates StandardCheckpointHandler. @@ -79,12 +83,15 @@ def __init__( Can help to reduce the possibility of OOM's when large checkpoints are restored. multiprocessing_options: See orbax.checkpoint.options. + pytree_metadata_options: Options to control types like tuple and + namedtuple in pytree metadata. """ self._supported_types = checkpoint_utils.STANDARD_ARRAY_TYPES self._impl = pytree_checkpoint_handler.PyTreeCheckpointHandler( save_concurrent_gb=save_concurrent_gb, restore_concurrent_gb=restore_concurrent_gb, multiprocessing_options=multiprocessing_options, + pytree_metadata_options=pytree_metadata_options, ) def _validate_save_state( diff --git a/checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py b/checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py index 4d4ae401f..8395706b7 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py +++ b/checkpoint/orbax/checkpoint/_src/testing/test_tree_utils.py @@ -51,7 +51,7 @@ class EmptyNamedTuple(NamedTuple): class NamedTupleWithNestedAttributes(NamedTuple): nested_mu_nu: MuNu | None = None nested_dict: Dict[str, jax.Array] | None = None - nested_tuple: Tuple[jax.Array, np.ndarray] | None = None + nested_tuple: Tuple[jax.Array, jax.Array] | None = None nested_empty_named_tuple: EmptyNamedTuple | None = None my_empty_chex: MyEmptyChex | None = None @@ -823,7 +823,7 @@ def __repr__(self): 'named_tuple_with_nested_attrs': NamedTupleWithNestedAttributes( nested_mu_nu=MuNu(mu=jnp.arange(8), nu=np.arange(8)), nested_dict={'a': jnp.arange(8), 'b': np.arange(8)}, - nested_tuple=(jnp.arange(8), np.arange(8)), + nested_tuple=(jnp.arange(8), jnp.arange(8)), nested_empty_named_tuple=EmptyNamedTuple(), my_empty_chex=MyEmptyChex(), ) @@ -856,7 +856,7 @@ def __repr__(self): skip_deserialize=False, ), tree_metadata.ValueMetadataEntry( - value_type='np.ndarray', + value_type='jax.Array', skip_deserialize=False, ), ], @@ -917,7 +917,7 @@ def __repr__(self): skip_deserialize=False, ), tree_metadata.ValueMetadataEntry( - value_type='np.ndarray', + value_type='jax.Array', skip_deserialize=False, ), ),