Skip to content

Commit

Permalink
Save StepMetadata in the Checkpointer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707670612
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Dec 19, 2024
1 parent b79a298 commit 6e9e8b6
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`CompositeCheckpointHandler.metadata()` to retrieve item metadata by
default-constructing `CheckpointHandler`s when they're listed in the saved
`StepMetadata` but aren't found in the checkpoint.
- `FileOptions.format` to specify the underlying checkpointing file format.

### Fixed
- Ignore not-exists and not-dir errors while building step metadata in
`_StandardNameFormat`.

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.
- `Checkpointer.save()` also saves `StepMetadata`.

## [0.10.2] - 2024-12-04

Expand Down
70 changes: 68 additions & 2 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
Expand All @@ -42,6 +43,7 @@
get_legacy_handler_wrapper = (
composite_checkpoint_handler.get_legacy_handler_wrapper
)
StepMetadata = checkpoint.StepMetadata


def construct_checkpoint_args(
Expand Down Expand Up @@ -161,7 +163,12 @@ async def create_temporary_path(
return tmpdir

def save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
self,
directory: epath.PathLike,
*args,
force: bool = False,
custom: dict[str, Any] | None = None,
**kwargs,
):
"""Saves the given item to the provided directory.
Expand All @@ -176,6 +183,8 @@ def save(
*args: additional args to provide to the CheckpointHandler's save method.
force: if True, allows overwriting an existing directory. May add overhead
due to the need to delete any existing files.
custom: a dictionary of custom metadata to be written to the checkpoint
directory via StepMetadata.
**kwargs: additional keyword args to provide to the CheckpointHandler's
save method.
Expand Down Expand Up @@ -226,6 +235,17 @@ def save(
processes=self._active_processes,
)

if utils.is_primary_host(self._primary_host):
self._save_step_metadata(directory, custom=custom)
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:step_metadata_save',
prefix=self._barrier_sync_key_prefix,
suffix=directory.name,
),
processes=self._active_processes,
)

def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any:
"""See superclass documentation."""
directory = epath.Path(directory)
Expand All @@ -251,11 +271,57 @@ def _restore(
) -> Any:
return self._handler.restore(directory, args=args)

def metadata(self, directory: epath.PathLike) -> Optional[Any]:
def metadata(self, directory: epath.PathLike) -> StepMetadata | Any | None:
"""See superclass documentation."""
directory = epath.Path(directory)
return self._handler.metadata(directory)

def _save_step_metadata(
    self, directory: epath.Path, custom: dict[str, Any] | None
):
  """Builds a StepMetadata record and writes it into `directory`.

  The record carries the file format from `self._file_options` and the
  caller-supplied `custom` dictionary. Per-item metadata and handler type
  information are filled in from `self._handler` when it can provide them.

  Args:
    directory: checkpoint directory that the metadata describes and is
      written to. If it does not exist, a warning is logged and nothing is
      written.
    custom: custom metadata to store on the StepMetadata, or None.
  """
  if not directory.exists():
    logging.warning(
        'Checkpoint at %s not found. Skipping step metadata save.', directory
    )
    return

  step_metadata = StepMetadata(
      format=self._file_options.format,
      custom=custom,
  )
  handler = self._handler
  if not isinstance(
      handler, composite_checkpoint_handler.CompositeCheckpointHandler
  ):
    # Single-handler case: the handler's own metadata is the item metadata.
    try:
      step_metadata.item_metadata = handler.metadata(directory)
    except (FileNotFoundError, NotImplementedError):
      logging.warning(
          'Failed to get handler metadata from directory %s.',
          directory,
      )
    # The handler type string is recorded even if its metadata was not
    # available.
    step_metadata.item_handlers = handler.typestr()
  else:
    # Composite case: the handler already returns a StepMetadata; copy over
    # its per-item fields only.
    try:
      composite_metadata = handler.metadata(directory)
    except (FileNotFoundError, NotImplementedError):
      logging.warning(
          'Failed to get per-item metadata from directory %s. Handler types '
          'will not be saved.',
          directory,
      )
    else:
      step_metadata.item_metadata = composite_metadata.item_metadata
      step_metadata.item_handlers = composite_metadata.item_handlers

  metadata_path = checkpoint.step_metadata_file_path(directory)
  serialized = step_metadata_serialization.serialize(step_metadata)
  self._metadata_store.update(file_path=metadata_path, **serialized)

def close(self):
"""Closes the underlying CheckpointHandler."""
self._handler.close()
Expand Down
5 changes: 5 additions & 0 deletions checkpoint/orbax/checkpoint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@



_ORBAX_STANDARD_FORMAT = 'orbax-standard'


@dataclasses.dataclass
class AsyncOptions:
"""Options used to configure async behavior.
Expand Down Expand Up @@ -65,9 +68,11 @@ class FileOptions:
metadata files. e.g. 0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported. default=None.
format: The checkpoint file format. Defaults to 'orbax-standard'.
"""

path_permission_mode: Optional[int] = None
format: str = _ORBAX_STANDARD_FORMAT



0 comments on commit 6e9e8b6

Please sign in to comment.