Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

internal change #1312

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.9.0] - 2024-11-08
### Added
- Introduce `CheckpointManagerOptions.should_keep_fn` as an alternative to Introduce `CheckpointManagerOptions.keep_period`.

## [0.9.0] - 2024-11-08
### Changed
- Create `Composite` class, which `CompositeArgs` now subclasses.
- Move `tree` to `_src`.
Expand Down
27 changes: 26 additions & 1 deletion checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class CheckpointManagerOptions:
keep_period:
If set, any existing checkpoints matching checkpoint_step % keep_period == 0
will not be deleted.
should_keep_fn:
WARNING: It is an experimental feature and may change without notice.
If set then this Callable overrides the behavior of `keep_period`, which is
ignored. It does not change the behavior of `keep_time_interval` and
`keep_checkpoints_without_metrics`. It is a predicate with signature:
`Callable[[int], bool]`, where int is the step number of the checkpoint.
best_fn:
If set, maintains checkpoints based on the quality of given
metrics rather than recency. The function should accept a PyTree of metrics,
Expand Down Expand Up @@ -256,6 +262,7 @@ class CheckpointManagerOptions:
max_to_keep: Optional[int] = None
keep_time_interval: Optional[datetime.timedelta] = None
keep_period: Optional[int] = None
should_keep_fn: Optional[Callable[[int], bool]] = None
best_fn: Optional[Callable[[PyTree], float]] = None
best_mode: str = 'max'
keep_checkpoints_without_metrics: bool = True
Expand Down Expand Up @@ -306,8 +313,10 @@ def __post_init__(self):
)
if self.read_only and self.keep_period is not None:
self.keep_period = None
self.should_keep_fn = None
logging.warning(
'CheckpointManagerOptions.read_only=True, setting keep_period=None.'
'CheckpointManagerOptions.read_only=True, setting keep_period=None'
' and should_keep_fn=None.'
)
if self.read_only and self.create:
self.create = False
Expand Down Expand Up @@ -337,6 +346,13 @@ def __post_init__(self):
'CheckpointManagerOptions.read_only=True, setting'
' should_save_fn=None.'
)
if self.should_keep_fn is not None:
logging.warning(
'CheckpointManagerOptions.should_keep_fn is set, setting'
' keep_period=None (was %s).',
self.keep_period,
)
self.keep_period = None
self.save_on_steps = frozenset(self.save_on_steps or ())


Expand Down Expand Up @@ -1614,6 +1630,15 @@ def _get_old_steps_to_remove(self) -> List[int]:
)
kept_checkpoints.add(info)
continue
if (
self._options.should_keep_fn is not None
and self._options.should_keep_fn(info.step)
):
logging.info(
'Preserving %s: (Reason: on should_keep_fn callback).', info
)
kept_checkpoints.add(info)
continue
if (
self._options.keep_period is not None
and info.step % self._options.keep_period == 0
Expand Down
Loading