Change the way load_path is handled (#132)
epwalsh authored Jan 7, 2025
1 parent 03a597a commit 7519e0a
Showing 5 changed files with 42 additions and 16 deletions.
1 change: 0 additions & 1 deletion .github/workflows/main.yml
@@ -184,7 +184,6 @@ jobs:
cluster:
# H100 clusters
- ai2/jupiter-cirrascale-2
-   - ai2/pluto-cirrascale
- ai2/augusta-google-1
# A100 clusters
- ai2/saturn-cirrascale
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Changed storage of shared shard state in sharded checkpoints from smallest shard to lowest rank (normally 0).
+ - Changed how the trainer handles loading a checkpoint when `load_path` is provided. Now `load_path` is only used if no checkpoint is found in the `save_folder`.

### Fixed

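In practice, the entry above fixes a precedence: the `save_folder` is always checked first, and `load_path` only matters while the run has no checkpoints of its own. A minimal sketch of that resolution order, assuming hypothetical `has_checkpoint`/`find_checkpoint` helpers (the real logic lives in the `trainer.py` diff below):

```python
from pathlib import Path
from typing import Optional

def has_checkpoint(folder: str) -> bool:
    # Hypothetical stand-in: treat any non-empty directory as holding a
    # checkpoint. The real Checkpointer has its own discovery logic.
    p = Path(folder)
    return p.is_dir() and any(p.iterdir())

def find_checkpoint(save_folder: str, load_path: Optional[str] = None) -> Optional[str]:
    if has_checkpoint(save_folder):
        return save_folder  # an in-progress run always resumes from its own folder
    if load_path is not None and has_checkpoint(load_path):
        return load_path  # load_path is now only a fallback
    return None  # train from scratch (or raise, under LoadStrategy.always)
```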
5 changes: 4 additions & 1 deletion src/olmo_core/train/callbacks/checkpointer.py
@@ -216,7 +216,10 @@ def pre_train(self):
path for _, path in sorted(ephemeral_checkpoints, key=lambda x: x[0])
]
for path in self._ephemeral_checkpoints:
log.info(f"Collected existing ephemeral checkpoint at '{path}'")
log.info(
f"Found existing ephemeral checkpoint at '{path}' which will "
"be removed when the next checkpoint is saved"
)

def post_train_batch(self):
self._await_last_checkpoint(blocking=False)
8 changes: 5 additions & 3 deletions src/olmo_core/train/common.py
@@ -75,17 +75,19 @@ class LoadStrategy(StrEnum):

if_available = "if_available"
"""
- Only load from the load path if a checkpoint exists there.
+ The trainer will attempt to load a checkpoint from the save folder or load path (in that order)
+ but will train from scratch if no checkpoint is found.
"""

always = "always"
"""
- Always try loading from the load path.
+ The trainer will attempt to load a checkpoint from the save folder or load path (in that order)
+ and raise an error if no checkpoint is found.
"""

never = "never"
"""
- Never load from the load path.
+ The trainer will never load a checkpoint even if one exists in the save folder or load path.
"""


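For reference, a hedged sketch of what each strategy now means given the new search order (`save_folder` first, then `load_path`); `build_trainer` is a hypothetical stand-in for however a `Trainer` gets constructed in your setup:

```python
from olmo_core.train.common import LoadStrategy

def build_trainer(load_strategy: LoadStrategy):
    ...  # hypothetical helper; only load_strategy matters for this sketch

# Default: resume from save_folder, else load_path, else train from scratch.
build_trainer(LoadStrategy.if_available)

# Strict: fit() raises FileNotFoundError if neither location has a checkpoint.
build_trainer(LoadStrategy.always)

# Scratch: never load a checkpoint, even if one exists in either location.
build_trainer(LoadStrategy.never)
```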
43 changes: 32 additions & 11 deletions src/olmo_core/train/trainer.py
@@ -173,8 +173,10 @@ class Trainer:

load_path: Optional[PathOrStr] = None
"""
- Where to load a checkpoint from prior to training.
- Defaults to ``save_folder``.
+ An alternative location to load a checkpoint from if no checkpoint is found in the current :data:`save_folder`.
+ This can be set to a checkpoint path or the path to a folder of checkpoints such as the :data:`save_folder`
+ from a different run.
"""

load_strategy: LoadStrategy = LoadStrategy.if_available
@@ -538,20 +540,38 @@ def check_if_canceled(self):

def fit(self):
"""
- Fit the model, potentially loading a checkpoint before hand depending on the
+ Fit the model, potentially loading a checkpoint first depending on the
:data:`load_strategy`.
"""
self._canceled = False
self._cancel_reason = None
self._canceling_rank = None

# Maybe load a checkpoint.
- if not self.checkpoint_loaded:
-     load_path = self.load_path if self.load_path is not None else self.save_folder
-     if self.load_strategy == LoadStrategy.always:
-         self.load_checkpoint(load_path)
-     elif self.load_strategy == LoadStrategy.if_available:
-         self.maybe_load_checkpoint(load_path)
+ if not self.checkpoint_loaded and self.load_strategy != LoadStrategy.never:
+     # Try loading from the save folder first.
+     self.maybe_load_checkpoint(self.save_folder)
+
+     # Then fall back to the load path, if provided.
+     if self.load_path is not None:
+         if not self.checkpoint_loaded:
+             self.maybe_load_checkpoint(self.load_path)
+         else:
+             log.warning(
+                 f"Ignoring load path ('{self.load_path}') since checkpoint was found in save folder"
+             )
+
+     if not self.checkpoint_loaded:
+         if self.load_strategy == LoadStrategy.always:
+             raise FileNotFoundError(
+                 f"No checkpoint found in save folder ('{self.save_folder}') or "
+                 f"load path ('{self.load_path}')"
+             )
+         else:
+             log.warning(
+                 f"No checkpoint found in save folder ('{self.save_folder}') or "
+                 f"load path ('{self.load_path}'), will train from scratch..."
+             )

log.info(f"Training for {self.max_steps:,d} steps")

@@ -709,9 +729,10 @@ def maybe_load_checkpoint(
load_optimizer_state=load_optimizer_state,
load_trainer_state=load_trainer_state,
)
+     assert self.checkpoint_loaded
+     return True
else:
-     log.warning(f"No checkpoint found in '{dir}', will train from scratch...")
-     return should_load
+     return False

def save_checkpoint(self) -> PathOrStr:
"""
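Taken together, the new `fit()` flow supports warm-starting one run from another while keeping restarts self-contained. A usage sketch, again with a hypothetical `build_trainer` helper and made-up paths:

```python
trainer = build_trainer(
    save_folder="gs://my-bucket/run-002",  # always checked first
    load_path="gs://my-bucket/run-001",    # used only while run-002 has no checkpoints of its own
)
trainer.fit()
# Once run-002 writes its first checkpoint, a restarted job resumes from
# save_folder and fit() logs a warning that load_path is being ignored.
```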
