Add support for tensor parallelism to CheckpointManager
cbalioglu committed Apr 26, 2024
1 parent 41cd4d1 commit 72bfc54
Showing 1 changed file with 102 additions and 48 deletions.
150 changes: 102 additions & 48 deletions src/fairseq2/checkpoint.py
@@ -125,7 +125,8 @@ def load_model(
:param out:
The model to load.
:param device:
The device on which to load ``out`` if it is on the meta device.
The device on which to load ``out`` if it is on the meta device;
ignored otherwise.
"""

@abstractmethod
@@ -135,7 +136,8 @@ def load_last_model(self, out: Module, *, device: Optional[Device] = None) -> in
:param out:
The model to load.
:param device:
The device on which to load ``out`` if it is on the meta device.
The device on which to load ``out`` if it is on the meta device;
ignored otherwise.
:returns:
The number of the training step.
@@ -145,11 +147,11 @@ def load_last_model(self, out: Module, *, device: Optional[Device] = None) -> in
def has_checkpoint(
self, step_nr: Optional[int] = None, *, with_model: bool = False
) -> bool:
"""Return ``True`` if the manager has a checkpoint.
"""Return ``True`` if the manager holds a checkpoint.
:param step_nr:
The number of the training step. If ``None``, returns ``True`` if
the manager has at least one checkpoint.
the manager holds at least one checkpoint.
:param with_model:
If ``True``, only considers training steps with a saved model.
"""
@@ -176,7 +178,9 @@ class FileCheckpointManager(CheckpointManager):
"""Saves and loads training checkpoints on a file system."""

_checkpoint_dir: Path
_gang: Gang
_root_gang: Gang
_dp_gang: Gang
_shard_suffix: str
_distributed_fs: bool
_model_key: str
_replicated_keys: Set[str]
@@ -186,6 +190,8 @@ def __init__(
checkpoint_dir: Path,
gang: Gang,
*,
dp_gang: Optional[Gang] = None,
tp_gang: Optional[Gang] = None,
distributed_fs: bool = True,
model_key: str = "model",
replicated_keys: Optional[Sequence[str]] = None,
@@ -195,6 +201,11 @@ def __init__(
The base directory under which to store the checkpoints.
:param gang:
The gang to coordinate the checkpoint operations.
:param dp_gang:
The gang used for data parallelism.
:param tp_gang:
The gang used for tensor parallelism. Must be specified if ``dp_gang``
is not ``None``.
:param distributed_fs:
If ``True``, the underlying file system of ``checkpoint_dir`` is
considered distributed (e.g. NFS).
@@ -204,13 +215,28 @@ def __init__(
The keys in provided checkpoints whose values are replicated across
all processes in the gang.
"""
self._gang = gang
self._root_gang = gang

self._dp_gang = gang

self._shard_suffix = ""

if dp_gang is not None and tp_gang is not None:
self._dp_gang = dp_gang

if tp_gang.size > 1:
self._shard_suffix = f".{tp_gang.rank}"
elif dp_gang is not None or tp_gang is not None:
raise ValueError("`dp_gang` and `tp_gang` must be both specified.")

self._distributed_fs = distributed_fs

if distributed_fs:
self._checkpoint_dir = checkpoint_dir
else:
self._checkpoint_dir = checkpoint_dir.joinpath(f"rank_{gang.rank}")
self._checkpoint_dir = checkpoint_dir.joinpath(
f"rank_{self._dp_gang.rank}{self._shard_suffix}"
)

self._model_key = model_key
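
For context, a minimal construction sketch follows. The gang-creation helpers (`make_root_gang`, `split_into_dp_tp_gangs`) are placeholders standing in for whatever utility builds the process gangs in a given setup; they, the paths, and the gang sizes are illustrative assumptions, not part of this diff.

```python
from pathlib import Path

from fairseq2.checkpoint import FileCheckpointManager

# Placeholder helpers, assumed to exist in the training script; any code that
# yields a root gang plus matching data- and tensor-parallel sub-gangs works.
root_gang = make_root_gang()
dp_gang, tp_gang = split_into_dp_tp_gangs(root_gang, tp_size=2)

checkpoint_manager = FileCheckpointManager(
    Path("/checkpoints/my_run"),  # hypothetical checkpoint directory
    root_gang,                    # coordinates barriers across all processes
    dp_gang=dp_gang,              # data-parallel dimension
    tp_gang=tp_gang,              # tensor-parallel dimension; with tp_size > 1,
                                  # files get a ".{tp_rank}" shard suffix
)
```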

@@ -246,33 +272,34 @@ def raise_error(cause: Exception) -> NoReturn:

tmp_step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp")

if self._gang.rank == 0 or not self._distributed_fs:
if self._root_gang.rank == 0 or not self._distributed_fs:
try:
tmp_step_dir.mkdir(parents=True)
except OSError as ex:
raise_error(ex)

self._gang.barrier()
self._root_gang.barrier()

# Do not modify the argument in-place. In case we fail, it should stay
# intact.
rank_part = checkpoint.copy()

# If the model is replicated, always save it into its own file.
if self._model_key in self._replicated_keys or "*" in self._replicated_keys:
if (state_dict := checkpoint.pop(self._model_key, None)) is not None:
if self._gang.rank == 0 or not self._distributed_fs:
model_file = tmp_step_dir.joinpath("model.pt")
if self._model_replicated():
if (state_dict := rank_part.pop(self._model_key, None)) is not None:
if self._dp_gang.rank == 0 or not self._distributed_fs:
model_file = tmp_step_dir.joinpath(f"model{self._shard_suffix}.pt")

try:
torch.save({"model": state_dict}, model_file)
except (RuntimeError, OSError, PickleError) as ex:
raise_error(ex)

self._gang.barrier()
self._root_gang.barrier()

rank_part = checkpoint.copy()

# For non-distributed file systems, we disregard the replicated keys and
# force each process in the gang to save the full checkpoint.
# For non-distributed file systems, we ignore the replicated keys and
# force each process to save the full checkpoint.
if self._replicated_keys and self._distributed_fs:
if self._gang.rank == 0:
if self._dp_gang.rank == 0:
replicated_part = {}

if "*" in self._replicated_keys:
@@ -285,7 +312,9 @@ def raise_error(cause: Exception) -> NoReturn:
pass

if replicated_part:
replicated_file = tmp_step_dir.joinpath("replicated.pt")
replicated_file = tmp_step_dir.joinpath(
f"replicated{self._shard_suffix}.pt"
)

try:
torch.save(replicated_part, replicated_file)
@@ -301,41 +330,45 @@ def raise_error(cause: Exception) -> NoReturn:
except KeyError:
pass

self._gang.barrier()
self._root_gang.barrier()

# Check if anything is left to save for the rank.
skip_rank = not rank_part
else:
skip_rank = False

if not skip_rank:
rank_file = tmp_step_dir.joinpath(f"rank_{self._gang.rank}.pt")
rank_file = tmp_step_dir.joinpath(
f"rank_{self._dp_gang.rank}{self._shard_suffix}.pt"
)

try:
torch.save(rank_part, rank_file)
except (RuntimeError, OSError, PickleError) as ex:
raise_error(ex)

self._gang.barrier()
self._root_gang.barrier()

if metadata is not None:
if self._gang.rank == 0 or not self._distributed_fs:
metadata_file = tmp_step_dir.joinpath("metadata.pt")
if self._dp_gang.rank == 0 or not self._distributed_fs:
metadata_file = tmp_step_dir.joinpath(
f"metadata{self._shard_suffix}.pt"
)

try:
torch.save(metadata, metadata_file)
except (RuntimeError, OSError, PickleError) as ex:
raise_error(ex)

self._gang.barrier()
self._root_gang.barrier()

if self._gang.rank == 0 or not self._distributed_fs:
if self._root_gang.rank == 0 or not self._distributed_fs:
try:
tmp_step_dir.replace(tmp_step_dir.with_suffix(""))
except OSError as ex:
raise_error(ex)

self._gang.barrier()
self._root_gang.barrier()
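
To make the naming scheme concrete, here is a small sketch of the files a successful save would leave in the step directory, assuming two data-parallel ranks, two tensor-parallel shards, a distributed file system, a replicated model, and a metadata argument; the sizes are illustrative assumptions, not part of the diff, and rank files are skipped when nothing remains after the replicated keys are split off.

```python
# Illustration only: reproduce the file-name pattern used above for an assumed
# dp_size=2, tp_size=2 run (the shard suffix encodes the tensor-parallel rank).
step_nr = 10
tp_size = 2

# One rank file per (data-parallel rank, tensor-parallel shard) pair.
for dp_rank in range(2):
    for tp_rank in range(tp_size):
        shard_suffix = f".{tp_rank}" if tp_size > 1 else ""
        print(f"step_{step_nr}/rank_{dp_rank}{shard_suffix}.pt")

# Model, replicated, and metadata files are written once per tensor-parallel
# shard, by data-parallel rank 0.
for tp_rank in range(tp_size):
    shard_suffix = f".{tp_rank}" if tp_size > 1 else ""
    print(f"step_{step_nr}/model{shard_suffix}.pt")
    print(f"step_{step_nr}/replicated{shard_suffix}.pt")
    print(f"step_{step_nr}/metadata{shard_suffix}.pt")
```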

@override
def load_checkpoint(self, step_nr: int) -> Dict[str, Any]:
@@ -346,10 +379,13 @@ def raise_error(cause: Exception) -> NoReturn:

parts = []

filenames = ["replicated.pt", f"rank_{self._gang.rank}.pt"]
filenames = [
f"replicated{self._shard_suffix}.pt",
f"rank_{self._dp_gang.rank}{self._shard_suffix}.pt",
]

if self._model_key in self._replicated_keys or "*" in self._replicated_keys:
filenames.append("model.pt")
if self._model_replicated():
filenames.append(f"model{self._shard_suffix}.pt")

step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}")

@@ -365,15 +401,15 @@ def raise_error(cause: Exception) -> NoReturn:

if part is not None:
# Restore the actual model key.
if filename == "model.pt" and self._model_key != "model":
if filename.startswith("model") and self._model_key != "model":
try:
part = {self._model_key: part["model"]}
except KeyError as ex:
raise_error(ex)

parts.append(part)

self._gang.barrier()
self._root_gang.barrier()

if not parts:
raise CheckpointNotFoundError(f"Training step {step_nr} has no checkpoint.")
@@ -386,6 +422,12 @@ def raise_error(cause: Exception) -> NoReturn:

return checkpoint

def _model_replicated(self) -> bool:
if self._dp_gang.size == 1:
return True

return self._model_key in self._replicated_keys or "*" in self._replicated_keys

@override
def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:
last_step_nr = self.get_last_step_number()
@@ -395,12 +437,14 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:
# If we don't have a distributed file system, we have to ensure that we
# have a consistent view of checkpoints across all processes.
if not self._distributed_fs:
gang = self._root_gang

step_numbers = torch.empty(
(self._gang.size,), device=self._gang.device, dtype=torch.int64
(gang.size,), device=gang.device, dtype=torch.int64
)

self._gang.all_gather(
step_numbers, torch.tensor(last_step_nr, device=self._gang.device)
self._root_gang.all_gather(
step_numbers, torch.tensor(last_step_nr, device=gang.device)
)

if not (step_numbers == last_step_nr).all():
@@ -416,7 +460,9 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]:

@override
def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]:
metadata_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/metadata.pt")
metadata_file = self._checkpoint_dir.joinpath(
f"step_{step_nr}/metadata{self._shard_suffix}.pt"
)

try:
metadata = load_checkpoint(metadata_file, map_location=CPU, mmap=True)
@@ -427,13 +473,13 @@ def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]:
f"The checkpoint metadata of training step {step_nr} cannot be loaded. See nested exception for details."
) from ex

self._gang.barrier()
self._root_gang.barrier()

return metadata

@override
def delete_checkpoint(self, step_nr: int, *, missing_ok: bool = False) -> None:
if self._gang.rank == 0 or not self._distributed_fs:
if self._root_gang.rank == 0 or not self._distributed_fs:
step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}")

try:
@@ -452,7 +498,7 @@ def delete_checkpoint(self, step_nr: int, *, missing_ok: bool = False) -> None:
f"The checkpoint of training step {step_nr} cannot be deleted. See nested exception for details."
) from ex

self._gang.barrier()
self._root_gang.barrier()

@override
def keep_last_n_checkpoints(self, n: int) -> None:
@@ -470,8 +516,12 @@ def save_consolidated_fsdp_model(self, step_nr: int, model: Module) -> None:
):
state_dict = model.state_dict()

if self._gang.rank == 0:
tmp_model_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/model.tmp")
self._root_gang.barrier()

if self._dp_gang.rank == 0:
tmp_model_file = self._checkpoint_dir.joinpath(
f"step_{step_nr}/model{self._shard_suffix}.tmp"
)

try:
torch.save({"model": state_dict}, tmp_model_file)
@@ -487,7 +537,7 @@ def save_consolidated_fsdp_model(self, step_nr: int, model: Module) -> None:
f"The model of training step {step_nr} cannot be saved. See nested exception for details."
) from ex

self._gang.barrier()
self._root_gang.barrier()

# compat
@override
@@ -498,7 +548,9 @@ def save_consolidated_model(self, step_nr: int, model: Module) -> None:
def load_model(
self, step_nr: int, out: Module, *, device: Optional[Device] = None
) -> None:
model_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/model.pt")
model_file = self._checkpoint_dir.joinpath(
f"step_{step_nr}/model{self._shard_suffix}.pt"
)

def raise_error(cause: Exception) -> NoReturn:
raise RuntimeError(
@@ -540,7 +592,7 @@ def raise_error(cause: Exception) -> NoReturn:
# have to explicitly initialize them.
reset_non_persistent_buffers(out)

self._gang.barrier()
self._root_gang.barrier()

@override
def load_last_model(self, out: Module, *, device: Optional[Device] = None) -> int:
@@ -565,7 +617,9 @@ def get_model_path(self, step_nr: Optional[int] = None) -> Optional[Path]:
if step_nr is None:
return None

return self._checkpoint_dir.joinpath(f"step_{step_nr}/model.pt")
return self._checkpoint_dir.joinpath(
f"step_{step_nr}/model{self._shard_suffix}.pt"
)

# compat
def get_model_checkpoint_path(
@@ -616,7 +670,7 @@ def _iter_step_numbers(self, with_model: bool) -> Iterator[int]:
# cached LOOKUP results.
self._clear_nfs_lookup_cache(step_dir)

if not step_dir.joinpath("model.pt").exists():
if not step_dir.joinpath(f"model{self._shard_suffix}.pt").exists():
continue

yield step_nr
