diff --git a/src/fairseq2/checkpoint.py b/src/fairseq2/checkpoint.py
index 44e5a1808..5e74806c9 100644
--- a/src/fairseq2/checkpoint.py
+++ b/src/fairseq2/checkpoint.py
@@ -125,7 +125,8 @@ def load_model(
         :param out:
             The model to load.
         :param device:
-            The device on which to load ``out`` if it is on the meta device.
+            The device on which to load ``out`` if it is on the meta device;
+            ignored otherwise.
         """
 
     @abstractmethod
@@ -135,7 +136,8 @@ def load_last_model(self, out: Module, *, device: Optional[Device] = None) -> in
         :param out:
             The model to load.
         :param device:
-            The device on which to load ``out`` if it is on the meta device.
+            The device on which to load ``out`` if it is on the meta device;
+            ignored otherwise.
 
         :returns:
             The number of the training step.
@@ -145,11 +147,11 @@
     def has_checkpoint(
         self, step_nr: Optional[int] = None, *, with_model: bool = False
     ) -> bool:
-        """Return ``True`` if the manager has a checkpoint.
+        """Return ``True`` if the manager holds a checkpoint.
 
         :param step_nr:
             The number of the training step. If ``None``, returns ``True`` if
-            the manager has at least one checkpoint.
+            the manager holds at least one checkpoint.
         :param with_model:
             If ``True``, only considers training steps with a saved model.
         """
@@ -176,7 +178,9 @@ class FileCheckpointManager(CheckpointManager):
     """Saves and loads training checkpoints on a file system."""
 
     _checkpoint_dir: Path
-    _gang: Gang
+    _root_gang: Gang
+    _dp_gang: Gang
+    _shard_suffix: str
     _distributed_fs: bool
     _model_key: str
     _replicated_keys: Set[str]
@@ -186,6 +190,8 @@ def __init__(
         checkpoint_dir: Path,
         gang: Gang,
         *,
+        dp_gang: Optional[Gang] = None,
+        tp_gang: Optional[Gang] = None,
         distributed_fs: bool = True,
         model_key: str = "model",
         replicated_keys: Optional[Sequence[str]] = None,
@@ -195,6 +201,11 @@ def __init__(
             The base directory under which to store the checkpoints.
         :param gang:
             The gang to coordinate the checkpoint operations.
+        :param dp_gang:
+            The gang used for data parallelism.
+        :param tp_gang:
+            The gang used for tensor parallelism. Must be specified if ``dp_gang``
+            is not ``None``.
         :param distributed_fs:
             If ``True``, the underlying file system of ``checkpoint_dir`` is
             considered distributed (e.g. NFS).
@@ -204,13 +215,28 @@ def __init__(
             The keys in provided checkpoints whose values are replicated across
             all processes in the gang.
         """
-        self._gang = gang
+        self._root_gang = gang
+
+        self._dp_gang = gang
+
+        self._shard_suffix = ""
+
+        if dp_gang is not None and tp_gang is not None:
+            self._dp_gang = dp_gang
+
+            if tp_gang.size > 1:
+                self._shard_suffix = f".{tp_gang.rank}"
+        elif dp_gang is not None or tp_gang is not None:
+            raise ValueError("`dp_gang` and `tp_gang` must both be specified.")
+
         self._distributed_fs = distributed_fs
 
         if distributed_fs:
             self._checkpoint_dir = checkpoint_dir
         else:
-            self._checkpoint_dir = checkpoint_dir.joinpath(f"rank_{gang.rank}")
+            self._checkpoint_dir = checkpoint_dir.joinpath(
+                f"rank_{self._dp_gang.rank}{self._shard_suffix}"
+            )
 
         self._model_key = model_key
 
@@ -246,33 +272,34 @@ def raise_error(cause: Exception) -> NoReturn:
 
         tmp_step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp")
 
-        if self._gang.rank == 0 or not self._distributed_fs:
+        if self._root_gang.rank == 0 or not self._distributed_fs:
             try:
                 tmp_step_dir.mkdir(parents=True)
             except OSError as ex:
                 raise_error(ex)
 
-        self._gang.barrier()
+        self._root_gang.barrier()
+
+        # Do not modify the argument in-place. In case we fail, it should stay
+        # intact.
+        rank_part = checkpoint.copy()
 
-        # If the model is replicated, always save it into its own file.
-        if self._model_key in self._replicated_keys or "*" in self._replicated_keys:
-            if (state_dict := checkpoint.pop(self._model_key, None)) is not None:
-                if self._gang.rank == 0 or not self._distributed_fs:
-                    model_file = tmp_step_dir.joinpath("model.pt")
+        if self._model_replicated():
+            if (state_dict := rank_part.pop(self._model_key, None)) is not None:
+                if self._dp_gang.rank == 0 or not self._distributed_fs:
+                    model_file = tmp_step_dir.joinpath(f"model{self._shard_suffix}.pt")
 
                     try:
                         torch.save({"model": state_dict}, model_file)
                     except (RuntimeError, OSError, PickleError) as ex:
                         raise_error(ex)
 
-            self._gang.barrier()
+            self._root_gang.barrier()
 
-        rank_part = checkpoint.copy()
-
-        # For non-distributed file systems, we disregard the replicated keys and
-        # force each process in the gang to save the full checkpoint.
+        # For non-distributed file systems, we ignore the replicated keys and
+        # force each process to save the full checkpoint.
         if self._replicated_keys and self._distributed_fs:
-            if self._gang.rank == 0:
+            if self._dp_gang.rank == 0:
                 replicated_part = {}
 
                 if "*" in self._replicated_keys:
@@ -285,7 +312,9 @@ def raise_error(cause: Exception) -> NoReturn:
                             pass
 
                 if replicated_part:
-                    replicated_file = tmp_step_dir.joinpath("replicated.pt")
+                    replicated_file = tmp_step_dir.joinpath(
+                        f"replicated{self._shard_suffix}.pt"
+                    )
 
                     try:
                         torch.save(replicated_part, replicated_file)
@@ -301,7 +330,7 @@ def raise_error(cause: Exception) -> NoReturn:
                         except KeyError:
                             pass
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
         # Check if anything is left to save for the rank.
         skip_rank = not rank_part
@@ -309,33 +338,37 @@ def raise_error(cause: Exception) -> NoReturn:
             skip_rank = False
 
         if not skip_rank:
-            rank_file = tmp_step_dir.joinpath(f"rank_{self._gang.rank}.pt")
+            rank_file = tmp_step_dir.joinpath(
+                f"rank_{self._dp_gang.rank}{self._shard_suffix}.pt"
+            )
 
             try:
                 torch.save(rank_part, rank_file)
            except (RuntimeError, OSError, PickleError) as ex:
                 raise_error(ex)
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
         if metadata is not None:
-            if self._gang.rank == 0 or not self._distributed_fs:
-                metadata_file = tmp_step_dir.joinpath("metadata.pt")
+            if self._dp_gang.rank == 0 or not self._distributed_fs:
+                metadata_file = tmp_step_dir.joinpath(
+                    f"metadata{self._shard_suffix}.pt"
+                )
 
                 try:
                     torch.save(metadata, metadata_file)
                 except (RuntimeError, OSError, PickleError) as ex:
                     raise_error(ex)
 
-            self._gang.barrier()
+            self._root_gang.barrier()
 
-        if self._gang.rank == 0 or not self._distributed_fs:
+        if self._root_gang.rank == 0 or not self._distributed_fs:
             try:
                 tmp_step_dir.replace(tmp_step_dir.with_suffix(""))
             except OSError as ex:
                 raise_error(ex)
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
     @override
     def load_checkpoint(self, step_nr: int) -> Dict[str, Any]:
@@ -346,10 +379,13 @@ def raise_error(cause: Exception) -> NoReturn:
 
         parts = []
 
-        filenames = ["replicated.pt", f"rank_{self._gang.rank}.pt"]
+        filenames = [
+            f"replicated{self._shard_suffix}.pt",
+            f"rank_{self._dp_gang.rank}{self._shard_suffix}.pt",
+        ]
 
-        if self._model_key in self._replicated_keys or "*" in self._replicated_keys:
-            filenames.append("model.pt")
+        if self._model_replicated():
+            filenames.append(f"model{self._shard_suffix}.pt")
 
         step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}")
 
@@ -365,7 +401,7 @@ def raise_error(cause: Exception) -> NoReturn:
 
             if part is not None:
                 # Restore the actual model key.
-                if filename == "model.pt" and self._model_key != "model":
+                if filename.startswith("model") and self._model_key != "model":
                     try:
                         part = {self._model_key: part["model"]}
                     except KeyError as ex:
@@ -373,7 +409,7 @@ def raise_error(cause: Exception) -> NoReturn:
 
                 parts.append(part)
 
-            self._gang.barrier()
+            self._root_gang.barrier()
 
         if not parts:
             raise CheckpointNotFoundError(f"Training step {step_nr} has no checkpoint.")
@@ -386,6 +422,12 @@ def raise_error(cause: Exception) -> NoReturn:
 
         return checkpoint
 
+    def _model_replicated(self) -> bool:
+        if self._dp_gang.size == 1:
+            return True
+
+        return self._model_key in self._replicated_keys or "*" in self._replicated_keys
+
     @override
     def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:
         last_step_nr = self.get_last_step_number()
@@ -395,12 +437,14 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:
         # If we don't have a distributed file system, we have to ensure that we
         # have a consistent view of checkpoints across all processes.
         if not self._distributed_fs:
+            gang = self._root_gang
+
             step_numbers = torch.empty(
-                (self._gang.size,), device=self._gang.device, dtype=torch.int64
+                (gang.size,), device=gang.device, dtype=torch.int64
             )
 
-            self._gang.all_gather(
-                step_numbers, torch.tensor(last_step_nr, device=self._gang.device)
+            self._root_gang.all_gather(
+                step_numbers, torch.tensor(last_step_nr, device=gang.device)
             )
 
             if not (step_numbers == last_step_nr).all():
@@ -416,7 +460,9 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:
 
     @override
     def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]:
-        metadata_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/metadata.pt")
+        metadata_file = self._checkpoint_dir.joinpath(
+            f"step_{step_nr}/metadata{self._shard_suffix}.pt"
+        )
 
         try:
             metadata = load_checkpoint(metadata_file, map_location=CPU, mmap=True)
@@ -427,13 +473,13 @@ def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]:
                 f"The checkpoint metadata of training step {step_nr} cannot be loaded. See nested exception for details."
             ) from ex
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
         return metadata
 
     @override
     def delete_checkpoint(self, step_nr: int, *, missing_ok: bool = False) -> None:
-        if self._gang.rank == 0 or not self._distributed_fs:
+        if self._root_gang.rank == 0 or not self._distributed_fs:
             step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}")
 
             try:
@@ -452,7 +498,7 @@ def delete_checkpoint(self, step_nr: int, *, missing_ok: bool = False) -> None:
                     f"The checkpoint of training step {step_nr} cannot be deleted. See nested exception for details."
                 ) from ex
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
     @override
     def keep_last_n_checkpoints(self, n: int) -> None:
@@ -470,8 +516,12 @@ def save_consolidated_fsdp_model(self, step_nr: int, model: Module) -> None:
         ):
             state_dict = model.state_dict()
 
-        if self._gang.rank == 0:
-            tmp_model_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/model.tmp")
+        self._root_gang.barrier()
+
+        if self._dp_gang.rank == 0:
+            tmp_model_file = self._checkpoint_dir.joinpath(
+                f"step_{step_nr}/model{self._shard_suffix}.tmp"
+            )
 
             try:
                 torch.save({"model": state_dict}, tmp_model_file)
@@ -487,7 +537,7 @@ def save_consolidated_fsdp_model(self, step_nr: int, model: Module) -> None:
                     f"The model of training step {step_nr} cannot be saved. See nested exception for details."
                ) from ex
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
     # compat
     @override
@@ -498,7 +548,9 @@ def save_consolidated_model(self, step_nr: int, model: Module) -> None:
     def load_model(
         self, step_nr: int, out: Module, *, device: Optional[Device] = None
     ) -> None:
-        model_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/model.pt")
+        model_file = self._checkpoint_dir.joinpath(
+            f"step_{step_nr}/model{self._shard_suffix}.pt"
+        )
 
         def raise_error(cause: Exception) -> NoReturn:
             raise RuntimeError(
@@ -540,7 +592,7 @@ def raise_error(cause: Exception) -> NoReturn:
             # have to explicitly initialize them.
            reset_non_persistent_buffers(out)
 
-        self._gang.barrier()
+        self._root_gang.barrier()
 
     @override
     def load_last_model(self, out: Module, *, device: Optional[Device] = None) -> int:
@@ -565,7 +617,9 @@ def get_model_path(self, step_nr: Optional[int] = None) -> Optional[Path]:
             if step_nr is None:
                 return None
 
-        return self._checkpoint_dir.joinpath(f"step_{step_nr}/model.pt")
+        return self._checkpoint_dir.joinpath(
+            f"step_{step_nr}/model{self._shard_suffix}.pt"
+        )
 
     # compat
     def get_model_checkpoint_path(
@@ -616,7 +670,7 @@ def _iter_step_numbers(self, with_model: bool) -> Iterator[int]:
                     # cached LOOKUP results.
                     self._clear_nfs_lookup_cache(step_dir)
 
-                if not step_dir.joinpath("model.pt").exists():
+                if not step_dir.joinpath(f"model{self._shard_suffix}.pt").exists():
                     continue
 
             yield step_nr
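
Usage note (not part of the patch): the sketch below shows how the new ``dp_gang``/``tp_gang`` arguments introduced above might be wired up. The three gang objects are assumed to be created elsewhere by the training setup, and the checkpoint path and helper name are illustrative; only the ``FileCheckpointManager`` constructor and the query calls come from the code in this diff.

    from pathlib import Path

    from fairseq2.checkpoint import FileCheckpointManager
    from fairseq2.gang import Gang

    def create_checkpoint_manager(
        root_gang: Gang, dp_gang: Gang, tp_gang: Gang
    ) -> FileCheckpointManager:
        # `root_gang` drives the barriers across all processes; rank 0 of
        # `dp_gang` writes the replicated/model/metadata files; when
        # `tp_gang.size > 1`, every file name gets a `.{tp_gang.rank}` shard
        # suffix (e.g. `model.0.pt`, `model.1.pt`).
        return FileCheckpointManager(
            Path("/checkpoints/my_run"),  # illustrative path
            root_gang,
            dp_gang=dp_gang,
            tp_gang=tp_gang,
        )

    # manager = create_checkpoint_manager(root_gang, dp_gang, tp_gang)
    # if manager.has_checkpoint(with_model=True):
    #     step_nr, state = manager.load_last_checkpoint()

When tensor parallelism is not used (``tp_gang.size == 1``), the shard suffix stays empty, so the on-disk layout of existing runs is unchanged.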