diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 780fb33e8..a6fe98f9d 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -34,6 +34,8 @@ from .initialize import get_tensor_model_parallel_group from .initialize import get_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_world_size +from .initialize import get_sequence_data_parallel_group, get_sequence_data_parallel_rank, get_sequence_data_parallel_world_size +from .initialize import get_sequence_parallel_group, get_sequence_parallel_rank, get_sequence_parallel_world_size from .initialize import get_io_parallel_group from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index 19d231524..e7b3a2e9a 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -28,6 +28,11 @@ _DATA_PARALLEL_GROUP = None # Pipeline parallel group that the current rank belongs to. _PIPE_PARALLEL_GROUP = None +# Sequence parallel group that the current rank belongs to. # TODO: do we need this? +_SEQ_PARALLEL_GROUP = None +# Sequence + data parallel group that the current rank belongs to. +# used to determine group over which ZeRO should allreduce grads --> keep sharded LNs in sync +_SEQ_DATA_PARALLEL_GROUP = None # A group used to sync during the IO process. Usually this is data_parallel_group(), # but with pipeline parallelism it must also involve the last stage (which is not in the @@ -50,6 +55,7 @@ def is_unitialized(): return _DATA_PARALLEL_GROUP is None +# TODO: this fn should not set up distinct seq or seq+DP groups (make them None and just DP groups, respectively) when sequence_parallel is false def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False): """ Initialize model data parallel groups. 
@@ -128,31 +134,75 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce # Build the model parallel groups. global _MODEL_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + # we also build the sequence parallel group + # (which is, in our case for now, identical always to the model parallel group.) + global _SEQ_PARALLEL_GROUP + assert _SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized" if topology: # Short circuit case without model parallelism. # TODO: it would be nice to avoid this branching case? if model_parallel_size == 1: for group_rank in range(world_size): group = torch.distributed.new_group(ranks=[group_rank]) + seq_group = torch.distributed.new_group(ranks=[group_rank]) if rank == 0: print(f"MPU MP:", [group_rank]) + print(f"MPU SeqP:", [group_rank]) if rank == group_rank: _MODEL_PARALLEL_GROUP = group - return + _SEQ_PARALLEL_GROUP = seq_group + return # TODO: why does this short circuit? Seems off. for mp_group in topology.get_axis_comm_lists("model"): group = torch.distributed.new_group(ranks=mp_group) + seq_group = torch.distributed.new_group(ranks=mp_group) if rank == 0: print(f"MPU MP:", mp_group) + print(f"MPU SeqP:", mp_group) if rank in mp_group: _MODEL_PARALLEL_GROUP = group + _SEQ_PARALLEL_GROUP = seq_group else: for i in range(world_size // model_parallel_size): ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) group = torch.distributed.new_group(ranks) + seq_group = torch.distributed.new_group(ranks) if i == (rank // model_parallel_size): _MODEL_PARALLEL_GROUP = group + _SEQ_PARALLEL_GROUP = seq_group + + # Build the sequence+data parallel group. + global _SEQ_DATA_PARALLEL_GROUP + assert _SEQ_DATA_PARALLEL_GROUP is None, "sequence+data parallel group is already initialized" + if topology: + # Short circuit case without model parallelism. + # TODO: it would be nice to avoid this branching case? 
+ if model_parallel_size == 1: + # seq+data just becomes data parallel group + for dp_group in topology.get_axis_comm_lists("data"): + group = torch.distributed.new_group(ranks=dp_group) + if rank in dp_group: + _SEQ_DATA_PARALLEL_GROUP = group + return + + for pp_group_idx in range(topology.get_dim("pipe")): + tpdp_group = topology.filter_match(pipe=pp_group_idx) # per PP stage, what are all the corresponding TP and DP ranks? + # our seq+data parallel group should allow us to reduce grads across both + seq_dp_group = torch.distributed.new_group(ranks=tpdp_group) + if rank == 0: + print(f"MPU Seq+DP:", tpdp_group) + if rank in tpdp_group: + _SEQ_DATA_PARALLEL_GROUP = seq_dp_group + + else: + raise NotImplementedError("Assume that we always pass a topology object to mpu.initialize_model_parallel .") + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) + group = torch.distributed.new_group(ranks) + seq_group = torch.distributed.new_group(ranks) + if i == (rank // model_parallel_size): + _SEQ_DATA_PARALLEL_GROUP = seq_group global _FP32_ALLREDUCE assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized" @@ -252,7 +302,7 @@ def get_topology(): def get_pipe_parallel_group(): """Get the pipe parallel group the caller rank belongs to.""" - assert _PIPE_PARALLEL_GROUP is not None, "data parallel group is not initialized" + assert _PIPE_PARALLEL_GROUP is not None, "pipe parallel group is not initialized" return _PIPE_PARALLEL_GROUP @@ -298,6 +348,41 @@ def get_tensor_model_parallel_rank(): return get_model_parallel_rank() +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert _SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized" + return _SEQ_PARALLEL_GROUP + + +def get_sequence_data_parallel_group(): + """Get the sequence+data parallel group the caller rank belongs to.""" + assert _SEQ_DATA_PARALLEL_GROUP is
not None, "sequence+data parallel group is not initialized" + return _SEQ_DATA_PARALLEL_GROUP + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return torch.distributed.get_world_size(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_world_size(): + """Return world size for the sequence+data parallel group.""" + return torch.distributed.get_world_size(group=get_sequence_data_parallel_group()) + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return torch.distributed.get_rank(group=get_sequence_parallel_group()) + + +def get_sequence_data_parallel_rank(): + """Return my rank for the sequence+data parallel group.""" + return torch.distributed.get_rank(group=get_sequence_data_parallel_group()) + + +# TODO: other groups use a set_... and ..._src_rank fn. Do we need these too for seq, seq+data? + + def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP @@ -308,6 +393,12 @@ def destroy_model_parallel(): _PIPE_PARALLEL_GROUP = None global _IO_PARALLEL_GROUP _IO_PARALLEL_GROUP = None + + global _SEQ_PARALLEL_GROUP + _SEQ_PARALLEL_GROUP = None + global _SEQ_DATA_PARALLEL_GROUP + _SEQ_DATA_PARALLEL_GROUP = None + global _MPU_WORLD_SIZE global _MPU_RANK _MPU_WORLD_SIZE = None