diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 232a3bc3..3712cfe7 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.9.0] - 2024-11-08 +### Added +- Introduce `CheckpointManagerOptions.should_keep_fn` as an alternative to `CheckpointManagerOptions.keep_period`. +## [0.9.0] - 2024-11-08 ### Changed - Create `Composite` class, which `CompositeArgs` now subclasses. - Move `tree` to `_src`. diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 81a0078d..080f6e87 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -190,6 +190,12 @@ class CheckpointManagerOptions: keep_period: If set, any existing checkpoints matching checkpoint_step % keep_period == 0 will not be deleted. + should_keep_fn: + WARNING: It is an experimental feature and may change without notice. + If set then this Callable overrides the behavior of `keep_period`, which is + ignored. It does not change the behavior of `keep_time_interval` and + `keep_checkpoints_without_metrics`. It is a predicate with signature: + `Callable[[int], bool]`, where int is the step number of the checkpoint. best_fn: If set, maintains checkpoints based on the quality of given metrics rather than recency. 
The function should accept a PyTree of metrics, @@ -256,6 +262,7 @@ class CheckpointManagerOptions: max_to_keep: Optional[int] = None keep_time_interval: Optional[datetime.timedelta] = None keep_period: Optional[int] = None + should_keep_fn: Optional[Callable[[int], bool]] = None best_fn: Optional[Callable[[PyTree], float]] = None best_mode: str = 'max' keep_checkpoints_without_metrics: bool = True @@ -306,8 +313,10 @@ def __post_init__(self): ) if self.read_only and self.keep_period is not None: self.keep_period = None + self.should_keep_fn = None logging.warning( - 'CheckpointManagerOptions.read_only=True, setting keep_period=None.' + 'CheckpointManagerOptions.read_only=True, setting keep_period=None' + ' and should_keep_fn=None.' ) if self.read_only and self.create: self.create = False @@ -337,6 +346,13 @@ def __post_init__(self): 'CheckpointManagerOptions.read_only=True, setting' ' should_save_fn=None.' ) + if self.should_keep_fn is not None: + logging.warning( + 'CheckpointManagerOptions.should_keep_fn is set, setting' + ' keep_period=None (was %s).', + self.keep_period, + ) + self.keep_period = None self.save_on_steps = frozenset(self.save_on_steps or ()) @@ -1614,6 +1630,15 @@ def _get_old_steps_to_remove(self) -> List[int]: ) kept_checkpoints.add(info) continue + if ( + self._options.should_keep_fn is not None + and self._options.should_keep_fn(info.step) + ): + logging.info( + 'Preserving %s: (Reason: on should_keep_fn callback).', info + ) + kept_checkpoints.add(info) + continue if ( self._options.keep_period is not None and info.step % self._options.keep_period == 0