Commit

preserve MegDS-like Seq+DP groups in MPU--useful if we were to add ulysses

haileyschoelkopf committed Aug 19, 2024
1 parent c93c1b4 commit 9117ccc
Showing 2 changed files with 95 additions and 2 deletions.
megatron/mpu/__init__.py (2 additions, 0 deletions)
@@ -34,6 +34,8 @@
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_sequence_data_parallel_group, get_sequence_data_parallel_rank, get_sequence_data_parallel_world_size
from .initialize import get_sequence_parallel_group, get_sequence_parallel_rank, get_sequence_parallel_world_size
from .initialize import get_io_parallel_group
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
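For orientation, here is a minimal usage sketch (not part of this commit) of the helpers the hunk above re-exports; it assumes torch.distributed is initialized and mpu.initialize_model_parallel has already been called.

# Illustrative sketch, not part of the diff: query the new sequence parallel
# and sequence+data parallel helpers after model parallel initialization.
from megatron import mpu

def report_sequence_groups():
    # Sequence parallel group (currently identical to the model parallel group).
    sp_rank = mpu.get_sequence_parallel_rank()
    sp_world = mpu.get_sequence_parallel_world_size()
    # Sequence + data parallel group (the group ZeRO grad allreduce would use).
    sdp_rank = mpu.get_sequence_data_parallel_rank()
    sdp_world = mpu.get_sequence_data_parallel_world_size()
    print(f"SeqP rank {sp_rank}/{sp_world}, Seq+DP rank {sdp_rank}/{sdp_world}")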
megatron/mpu/initialize.py (93 additions, 2 deletions)
Expand Up @@ -28,6 +28,11 @@
_DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPE_PARALLEL_GROUP = None
# Sequence parallel group that the current rank belongs to. # TODO: do we need this?
_SEQ_PARALLEL_GROUP = None
# Sequence + data parallel group that the current rank belongs to.
# used to determine group over which ZeRO should allreduce grads --> keep sharded LNs in sync
_SEQ_DATA_PARALLEL_GROUP = None

# A group used to sync during the IO process. Usually this is data_parallel_group(),
# but with pipeline parallelism it must also involve the last stage (which is not in the
@@ -50,6 +55,7 @@ def is_unitialized():
return _DATA_PARALLEL_GROUP is None


# TODO: this fn should not set up distinct seq or seq+DP groups (make them None and just DP groups, respectively) when sequence_parallel is false
def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False):
"""
Initialize model data parallel groups.
@@ -128,31 +134,75 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce
# Build the model parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
# we also build the sequence parallel group
# (which is, in our case for now, identical always to the model parallel group.)
global _SEQ_PARALLEL_GROUP
assert _SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
if topology:
# Short circuit case without model parallelism.
# TODO: it would be nice to avoid this branching case?
if model_parallel_size == 1:
for group_rank in range(world_size):
group = torch.distributed.new_group(ranks=[group_rank])
seq_group = torch.distributed.new_group(ranks=[group_rank])
if rank == 0:
print(f"MPU MP:", [group_rank])
print(f"MPU SeqP:", [group_rank])
if rank == group_rank:
_MODEL_PARALLEL_GROUP = group
return
_SEQ_PARALLEL_GROUP = seq_group
return # TODO: why does this short circuit? Seems off.

for mp_group in topology.get_axis_comm_lists("model"):
group = torch.distributed.new_group(ranks=mp_group)
seq_group = torch.distributed.new_group(ranks=mp_group)
if rank == 0:
print(f"MPU MP:", mp_group)
print(f"MPU SeqP:", mp_group)
if rank in mp_group:
_MODEL_PARALLEL_GROUP = group
_SEQ_PARALLEL_GROUP = seq_group

else:
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
seq_group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
_MODEL_PARALLEL_GROUP = group
_SEQ_PARALLEL_GROUP = seq_group

# Build the sequence+data parallel group.
global _SEQ_DATA_PARALLEL_GROUP
assert _SEQ_DATA_PARALLEL_GROUP is None, "sequence+data parallel group is already initialized"
if topology:
# Short circuit case without model parallelism.
# TODO: it would be nice to avoid this branching case?
if model_parallel_size == 1:
# seq+data just becomes data parallel group
for dp_group in topology.get_axis_comm_lists("data"):
group = torch.distributed.new_group(ranks=dp_group)
if rank in dp_group:
_SEQ_DATA_PARALLEL_GROUP = group
return

for pp_group_idx in range(topology.get_dim("pipe")):
tpdp_group = topology.filter_match(pipe=pp_group_idx) # per PP stage, what are all the corresponding TP and DP ranks?
# our seq+data parallel group should allow us to reduce grads across both
seq_dp_group = torch.distributed.new_group(ranks=tpdp_group)
if rank == 0:
print(f"MPU Seq+DP:", tpdp_group)
if rank in tpdp_group:
_SEQ_DATA_PARALLEL_GROUP = seq_dp_group

else:
raise NotImplementedError("Assume that we always pass a topology object to mpu.initialize_model_parallel .")
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
seq_group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
_SEQ_DATA_PARALLEL_GROUP = seq_group

global _FP32_ALLREDUCE
assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized"
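To make the resulting grouping concrete, here is a small standalone sketch (not part of the commit) that mirrors the topology calls used above for a hypothetical 8-rank job with pipe=2, model=2, data=2; it only prints the rank lists and creates no process groups.

# Illustrative sketch, not part of the diff: show which ranks the code above
# would place in each group for a hypothetical pipe=2, model=2, data=2 layout.
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

topo = PipeModelDataParallelTopology(num_pp=2, num_mp=2, num_dp=2)  # 8 ranks

# Model parallel groups (and, for now, the identical sequence parallel groups).
for mp_group in topo.get_axis_comm_lists("model"):
    print("MPU MP / SeqP:", mp_group)

# Sequence+data parallel groups: all model and data parallel ranks of a stage.
for pp_stage in range(topo.get_dim("pipe")):
    print("MPU Seq+DP:", topo.filter_match(pipe=pp_stage))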
@@ -252,7 +302,7 @@ def get_topology():

def get_pipe_parallel_group():
"""Get the pipe parallel group the caller rank belongs to."""
assert _PIPE_PARALLEL_GROUP is not None, "data parallel group is not initialized"
assert _PIPE_PARALLEL_GROUP is not None, "pipe parallel group is not initialized"
return _PIPE_PARALLEL_GROUP


@@ -298,6 +348,41 @@ def get_tensor_model_parallel_rank():
return get_model_parallel_rank()


def get_sequence_parallel_group():
"""Get the sequence parallel group the caller rank belongs to."""
assert _SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
return _SEQ_PARALLEL_GROUP


def get_sequence_data_parallel_group():
"""Get the sequence+data parallel group the caller rank belongs to."""
assert _SEQ_DATA_PARALLEL_GROUP is not None, "sequence+data parallel group is not initialized"
return _SEQ_DATA_PARALLEL_GROUP


def get_sequence_parallel_world_size():
"""Return world size for the sequence parallel group."""
return torch.distributed.get_world_size(group=get_sequence_parallel_group())


def get_sequence_data_parallel_world_size():
"""Return world size for the sequence parallel group."""
return torch.distributed.get_world_size(group=get_sequence_data_parallel_group())


def get_sequence_parallel_rank():
"""Return my rank for the sequence parallel group."""
return torch.distributed.get_rank(group=get_sequence_parallel_group())


def get_sequence_data_parallel_rank():
"""Return my rank for the sequence+data parallel group."""
return torch.distributed.get_rank(group=get_sequence_data_parallel_group())


# TODO: other groups use a set_... and ..._src_rank fn. Do we need these too for seq, seq+data?


def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
@@ -308,6 +393,12 @@ def destroy_model_parallel():
_PIPE_PARALLEL_GROUP = None
global _IO_PARALLEL_GROUP
_IO_PARALLEL_GROUP = None

global _SEQ_PARALLEL_GROUP
_SEQ_PARALLEL_GROUP = None
global _SEQ_DATA_PARALLEL_GROUP
_SEQ_DATA_PARALLEL_GROUP = None

global _MPU_WORLD_SIZE
global _MPU_RANK
_MPU_WORLD_SIZE = None
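As the comment on _SEQ_DATA_PARALLEL_GROUP notes, the intended consumer is the ZeRO gradient allreduce, so that sequence-sharded parameters such as LayerNorms stay in sync. Below is a hedged sketch of what that reduction could look like downstream; the function name and the plain parameter loop are illustrative assumptions, not code from this commit.

# Illustrative sketch, not part of the diff: average gradients over the
# sequence+data parallel group so sequence-sharded params stay in sync.
# `model` stands in for whatever wrapped module holds the parameters.
import torch
from megatron import mpu

def allreduce_grads_over_seq_data_group(model):
    group = mpu.get_sequence_data_parallel_group()
    world_size = mpu.get_sequence_data_parallel_world_size()
    for param in model.parameters():
        if param.grad is not None:
            torch.distributed.all_reduce(param.grad, group=group)
            param.grad.div_(world_size)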
