Merge branch 'main' into zyaoj/harden-lr-scheduler
zyaoj authored Dec 20, 2024
2 parents c663e4d + 9a89641 commit 5dbc048
Showing 4 changed files with 323 additions and 67 deletions.
205 changes: 157 additions & 48 deletions src/fairseq2/gang.py
@@ -707,6 +707,154 @@ def fake_gangs(device: Device) -> Gangs:
return Gangs(gang, gang, gang)


def _setup_2D_mesh_gangs(
root_gang: Gang,
*,
row_length: int = 1,
create_single_rank_process_groups: bool = False,
dim_descriptions: list[str] | None = None,
) -> dict[int, Gang]:
"""Set up gangs for this process as defined by a 2D device mesh.
The two returned gangs are determined by the process's position in the mesh:
the first gang is the process's mesh row, the second is its mesh column.
For example, assuming 8 devices denoted by g0 to g7, calling this function
with ``row_length`` = 4 amounts to defining the 2D mesh
[[g0, g1, g2, g3], [g4, g5, g6, g7]] and making 2 sets of gangs:
2 gangs of size 4 (mesh rows):
[g0, g1, g2, g3], [g4, g5, g6, g7]
4 gangs of size 2 (mesh columns):
[g0, g4], [g1, g5], [g2, g6], [g3, g7]
For the process of rank 5, the function would return the 2 sub-gangs
{0: [g4, g5, g6, g7], 1: [g1, g5]}. If adjacent ranks are on the same host
(for example, 2 hosts: one with g0 to g3, and the other with g4 to g7),
the first gang can be used to maximize local intra-host communication.
Example use cases include making tensor- and data-parallel gangs, or the
sharding and replication gangs used by FSDP's hybrid sharding.
:param root_gang:
The gang whose topology will be used to make the new gangs.
:param row_length:
The size of the gangs corresponding to the 2D mesh rows.
:param create_single_rank_process_groups:
If ``True``, create an underlying ``dist.ProcessGroup`` even for single-rank gangs;
otherwise, single-rank gangs are returned as fake gangs.
:param dim_descriptions:
String descriptions of returned gangs, used in log and error messages.
:returns:
A ``dict`` of two gangs: key 0 maps to the gang of the process's 2D mesh row,
key 1 maps to the gang of its 2D mesh column.
"""
row_count = root_gang.size // row_length

mesh = torch.arange(root_gang.size).view(row_count, row_length)

# Get the coordinate of this process in the mesh.
rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)]
mesh_shape = mesh.size()

output = {}

log.info(
"Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape)
)
if dim_descriptions is None:
dim_descriptions = [f"dim-{dim}" for dim in range(2)]

for dim in range(2):
current_subgang: Gang | None = None

gang_size = mesh_shape[1 - dim]

log.info(
"Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size
)

# Match row length (dim 0) or column length (dim 1)
match gang_size:
case 1:
if create_single_rank_process_groups:
current_subgang = root_gang.make_gang([root_gang.rank])
else:
current_subgang = FakeGang(device=root_gang.device)
case root_gang.size:
current_subgang = root_gang
case _:
# Create 1 gang per row (dim 0) or per column (dim 1)
for i in range(mesh_shape[dim]):
ranks = mesh[i, :] if dim == 0 else mesh[:, i]
sub_gang = root_gang.make_gang(ranks.tolist())
if i == rank_coords[dim]:
current_subgang = sub_gang

if current_subgang is None:
raise InternalError(f"`current_subgang` ({dim_descriptions[dim]}) is `None`.")

output[dim] = current_subgang

return output
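
Side note (illustrative only, not part of this commit): the row/column decomposition described in the docstring above can be reproduced with plain torch. The sketch below mirrors the rank-5 example, where the dim-0 gang is the rank's mesh row and the dim-1 gang is its mesh column; the sizes are hypothetical values taken from that example.

import torch

world_size, row_length, rank = 8, 4, 5  # hypothetical values from the docstring example

mesh = torch.arange(world_size).view(world_size // row_length, row_length)
row, col = (int(x) for x in torch.where(mesh == rank))

row_gang_ranks = mesh[row, :].tolist()  # dim 0 -> [4, 5, 6, 7]
col_gang_ranks = mesh[:, col].tolist()  # dim 1 -> [1, 5]
print({0: row_gang_ranks, 1: col_gang_ranks})
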


def setup_hybrid_fsdp_gangs(gang: Gang, local_world_size: int) -> tuple[Gang, Gang]:
"""Make gangs to be used for hybrid-sharding FSDP.
For instance, if we have 8 devices denoted by g0 to g7 and ``local_world_size``
is 4, this function will make 2 sharding gangs and 4 replication gangs:
2 sharding gangs of size 4:
[g0, g1, g2, g3], [g4, g5, g6, g7]
4 replication gangs of size 2:
[g0, g4], [g1, g5], [g2, g6], [g3, g7]
For efficiency, the caller should make sure adjacent ranks are on the same
host.
:param gang:
The gang over which to shard and replicate.
:param local_world_size:
``gang`` will be split into sub-gangs each containing
``local_world_size`` number of consecutive processes.
The model will be fully sharded within each sub-gang and
will be replicated across sub-gangs.
:returns:
A pair of gangs: the sharding gang that the current process is part of,
and the replication gang that the current process is part of.
if local_world_size < 1:
raise ValueError(
f"`local_world_size` must be greater than or equal to 1, but is {local_world_size} instead."
)

if local_world_size == 1:
raise GangError(
f"`local_world_size` must be greater than 1, but is {local_world_size} instead. This hybrid configuration would force FSDP to fall back to `NO_SHARD`, which is deprecated. Please use DDP instead."
)

if local_world_size > gang.size:
raise ValueError(
f"`local_world_size` must be less than or equal to `gang.size` ({gang.size}), but is {local_world_size} instead."
)

if gang.size % local_world_size != 0:
raise GangError(
f"`gang.size` ({gang.size}) must be a multiple of `local_world_size` ({local_world_size})."
)

sub_gangs = _setup_2D_mesh_gangs(
gang,
row_length=local_world_size,
create_single_rank_process_groups=True,
dim_descriptions=["sharding", "replication"],
)

return sub_gangs[0], sub_gangs[1]
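
To make the grouping in the docstring concrete, the following snippet (illustrative only, plain Python, not fairseq2 API) enumerates the sharding and replication groups for the 8-process, ``local_world_size = 4`` example; each rank belongs to exactly one group of each kind.

world_size, local_world_size = 8, 4  # hypothetical values matching the example above

sharding_groups = [
    list(range(start, start + local_world_size))
    for start in range(0, world_size, local_world_size)
]  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- the model is fully sharded within each group

replication_groups = [
    list(range(offset, world_size, local_world_size))
    for offset in range(local_world_size)
]  # [[0, 4], [1, 5], [2, 6], [3, 7]] -- the model is replicated across groups
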


def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
"""Make gangs to be used for data and tensor parallelism.
@@ -729,9 +877,9 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
The size of tensor parallel gangs.
:returns:
A ``dict`` of two gangs; (1) the data parallel gang that this process
is part of denoted by the key "dp", (2) the tensor parallel gang that
this process is part of denoted by the key "tp".
Three gangs: the root gang, the data parallel gang that this
process is part of, and the tensor parallel gang that this process is
part of.
"""
if tp_size <= 0:
raise ValueError(f"`tp_size` must be greater than 0, but is {tp_size} instead.")
@@ -741,52 +889,13 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs:
f"The number of processes in the root gang is expected to be a multiple of the tensor parallel size ({tp_size}), but is {root_gang.size} instead."
)

dp_size = root_gang.size // tp_size

mesh = torch.arange(root_gang.size).view(dp_size, tp_size)

# Get the coordinate of this process in the mesh.
rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)]
output_from_2D_mesh = _setup_2D_mesh_gangs(
root_gang,
row_length=tp_size,
dim_descriptions=["tensor parallel", "data parallel"],
)

dp_gang: Gang | None = None

log.info("Initializing data parallel gang with a size of {}.", dp_size)

# Build the gangs for data parallelism.
match dp_size:
case 1:
dp_gang = FakeGang(device=root_gang.device)
case root_gang.size:
dp_gang = root_gang
case _:
for i in range(tp_size):
sub_gang = root_gang.make_gang(mesh[:, i].tolist())
if i == rank_coords[1]:
dp_gang = sub_gang

if dp_gang is None:
raise InternalError("`dp_gang` is `None`.")

tp_gang: Gang | None = None

log.info("Initializing tensor parallel gang with a size of {}.", tp_size)

# Build the gangs for tensor parallelism.
match tp_size:
case 1:
tp_gang = FakeGang(device=root_gang.device)
case root_gang.size:
tp_gang = root_gang
case _:
for i in range(dp_size):
sub_gang = root_gang.make_gang(mesh[i, :].tolist())
if i == rank_coords[0]:
tp_gang = sub_gang

if tp_gang is None:
raise InternalError("`tp_gang` is `None`.")

return Gangs(root_gang, dp_gang, tp_gang)
return Gangs(root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0])
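
The index swap in the new return statement follows from the mesh orientation: with ``row_length = tp_size``, the mesh rows (key 0) are the tensor-parallel groups and the mesh columns (key 1) are the data-parallel groups, so ``Gangs(root, dp, tp)`` takes ``output_from_2D_mesh[1]`` and ``output_from_2D_mesh[0]`` respectively. An illustrative check (not part of the commit), assuming 8 ranks and ``tp_size = 2``:

import torch

world_size, tp_size = 8, 2  # hypothetical sizes

mesh = torch.arange(world_size).view(-1, tp_size)
tp_groups = [row.tolist() for row in mesh]      # key 0: [[0, 1], [2, 3], [4, 5], [6, 7]]
dp_groups = [col.tolist() for col in mesh.t()]  # key 1: [[0, 2, 4, 6], [1, 3, 5, 7]]
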


def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool:
36 changes: 17 additions & 19 deletions src/fairseq2/nn/fsdp.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
BackwardPrefetch,
@@ -27,7 +28,7 @@
)
from torch.nn import Module, Parameter

from fairseq2.gang import Gang
from fairseq2.gang import Gang, setup_hybrid_fsdp_gangs
from fairseq2.logging import log
from fairseq2.nn.utils.module import (
apply_to_parameters,
@@ -87,30 +88,27 @@ def to_fsdp(
If ``True``, the gradients will be reduced in full precision. Only
relevant if ``mixed_precision_dtype`` is not ``None``.
"""
process_group: ProcessGroup | tuple[ProcessGroup, ProcessGroup] | None = None

if local_world_size is not None:
if local_world_size == 0:
raise ValueError(
f"`local_world_size` must be greater than 0, but is {local_world_size} instead."
)

if local_world_size > gang.size:
raise ValueError(
f"`local_world_size` must be less than or equal to `gang.size` ({gang.size}), but is {local_world_size} instead."
)

if gang.size % local_world_size != 0:
raise ValueError(
f"`gang.size` ({gang.size}) must be a multiple of `local_world_size` ({local_world_size})."
)

# TODO(balioglu): Finish!
raise NotImplementedError("`local_world_size` is not supported yet.")
sharding_strategy = ShardingStrategy.HYBRID_SHARD

sharding_gang, replication_gang = setup_hybrid_fsdp_gangs(
gang, local_world_size
)

process_group = (
sharding_gang.as_process_group(),
replication_gang.as_process_group(),
)
else:
if reshard_after_forward:
sharding_strategy = ShardingStrategy.FULL_SHARD
else:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

process_group = gang.as_process_group()

if memory_policy is None:
memory_policy = FSDP_STANDARD_MEMORY_POLICY

@@ -156,7 +154,7 @@ def to_fsdp(

fsdp = FSDP(
module,
process_group=gang.as_process_group(),
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=CPUOffload() if memory_policy.cpu_offload else None,
auto_wrap_policy=wrap_policy,
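
For context on the PyTorch side of this change, a hedged sketch (not fairseq2 code): with ``ShardingStrategy.HYBRID_SHARD``, FSDP accepts a pair of process groups, the first used for sharding and the second for replication. The sketch assumes a ``torchrun`` launch with a CUDA world size divisible by ``local_world_size`` and wraps a toy module; the group layout and sizes are hypothetical.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
local_world_size = 4  # hypothetical group size

torch.cuda.set_device(rank % torch.cuda.device_count())

# new_group() is collective: every rank must create every group, in the same order.
shard_groups = [
    dist.new_group(list(range(s, s + local_world_size)))
    for s in range(0, world_size, local_world_size)
]
replicate_groups = [
    dist.new_group(list(range(o, world_size, local_world_size)))
    for o in range(local_world_size)
]

module = torch.nn.Linear(16, 16).cuda()  # toy module for illustration

fsdp_module = FSDP(
    module,
    process_group=(
        shard_groups[rank // local_world_size],     # shard within this group
        replicate_groups[rank % local_world_size],  # replicate across groups
    ),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
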
8 changes: 8 additions & 0 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -126,6 +126,13 @@ class InstructionFinetuneConfig:
data_parallelism: Literal["ddp", "fsdp"] = "fsdp"
"""The data parallelism API to use."""

fsdp_local_world_size: int | None = None
"""
If not ``None``, enables hybrid sharding. The model will be fully sharded
within each worker group of ``fsdp_local_world_size`` consecutive processes
and will be replicated across groups.
"""

fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "layer"
"""The granularity at which to wrap the model."""

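
As a usage note (illustrative only, assuming the remaining ``InstructionFinetuneConfig`` fields have workable defaults), hybrid sharding would be enabled from the recipe config roughly like this; the group size of 8 is a hypothetical choice.

from fairseq2.recipes.lm.instruction_finetune import InstructionFinetuneConfig

config = InstructionFinetuneConfig()
config.data_parallelism = "fsdp"
config.fsdp_local_world_size = 8  # shard within groups of 8 ranks, replicate across groups
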
@@ -407,6 +414,7 @@ def load_instruction_finetuner(
fsdp_mixed_precision_dtype=mp_dtype,
fsdp_fp32_reduce=True,
fsdp_wrap_granularity=config.fsdp_wrap_granularity,
fsdp_local_world_size=config.fsdp_local_world_size,
)

if config.activation_checkpointing:
Expand Down
