Skip to content

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694643070
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 8, 2024
1 parent 4ddd6ba commit fe86ca4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
4 changes: 3 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.9.0] - 2024-11-08
### Added
- Introduce `CheckpointManagerOptions.should_keep_fn` as an alternative to `CheckpointManagerOptions.keep_period`.

## [0.9.0] - 2024-11-08
### Changed
- Create `Composite` class, which `CompositeArgs` now subclasses.
- Move `tree` to `_src`.
Expand Down
27 changes: 26 additions & 1 deletion checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class CheckpointManagerOptions:
keep_period:
If set, any existing checkpoints matching checkpoint_step % keep_period == 0
will not be deleted.
should_keep_fn:
WARNING: This is an experimental feature and may change without notice.
If set, this Callable overrides `keep_period`, which is then ignored (it is
reset to None). It does not change the behavior of `keep_time_interval` and
`keep_checkpoints_without_metrics`. It is a predicate with signature:
`Callable[[int], bool]`, where int is the step number of the checkpoint;
a checkpoint is preserved when the predicate returns True for its step.
best_fn:
If set, maintains checkpoints based on the quality of given
metrics rather than recency. The function should accept a PyTree of metrics,
Expand Down Expand Up @@ -256,6 +262,7 @@ class CheckpointManagerOptions:
max_to_keep: Optional[int] = None
keep_time_interval: Optional[datetime.timedelta] = None
keep_period: Optional[int] = None
should_keep_fn: Optional[Callable[[int], bool]] = None
best_fn: Optional[Callable[[PyTree], float]] = None
best_mode: str = 'max'
keep_checkpoints_without_metrics: bool = True
Expand Down Expand Up @@ -306,8 +313,10 @@ def __post_init__(self):
)
if self.read_only and self.keep_period is not None:
self.keep_period = None
self.should_keep_fn = None
logging.warning(
'CheckpointManagerOptions.read_only=True, setting keep_period=None.'
'CheckpointManagerOptions.read_only=True, setting keep_period=None'
' and should_keep_fn=None.'
)
if self.read_only and self.create:
self.create = False
Expand Down Expand Up @@ -337,6 +346,13 @@ def __post_init__(self):
'CheckpointManagerOptions.read_only=True, setting'
' should_save_fn=None.'
)
if self.should_keep_fn is not None:
logging.warning(
'CheckpointManagerOptions.should_keep_fn is set, setting'
' keep_period=None (was %s).',
self.keep_period,
)
self.keep_period = None
self.save_on_steps = frozenset(self.save_on_steps or ())


Expand Down Expand Up @@ -1614,6 +1630,15 @@ def _get_old_steps_to_remove(self) -> List[int]:
)
kept_checkpoints.add(info)
continue
if (
self._options.should_keep_fn is not None
and self._options.should_keep_fn(info.step)
):
logging.info(
'Preserving %s: (Reason: on should_keep_fn callback).', info
)
kept_checkpoints.add(info)
continue
if (
self._options.keep_period is not None
and info.step % self._options.keep_period == 0
Expand Down

0 comments on commit fe86ca4

Please sign in to comment.