MOE: Update global norm calculation for pipeline
When using MoE with MoE-TP disabled, use the pipeline parallel group to max or sum
the MoE gradients.

This also fixes the behavior for the following configuration:
No pipeline, TP enabled, MoE-TP disabled.

Signed-off-by: Moshe Island <[email protected]>
misland-habana committed Mar 30, 2024
1 parent 5831321 commit 7a5e888
Showing 2 changed files with 46 additions and 3 deletions.
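To make the intent concrete before the diff: the sketch below restates the reduction-group selection this patch introduces as a small standalone function. It is an illustration only; the function name and its arguments are invented for this example, and the real code (in the diff below) issues dist.all_reduce calls directly, assuming an mpu object is supplied. In the configuration called out in the commit message (no pipeline, TP enabled, MoE-TP disabled), MoE gradients are now reduced only over the expert parallel group rather than also over the tensor parallel group.

# Illustrative sketch (not part of the commit): which process groups the
# gradient-norm reduction spans under the new logic, assuming an mpu is given.
# The function and argument names are invented for this example.
def norm_reduce_groups(is_moe_grad, moe_tp_world_size, pp_world_size):
    groups_to_reduce = []
    if not is_moe_grad or moe_tp_world_size > 1:
        # Dense grads, or MoE grads with MoE-TP enabled: reduce across model parallel.
        groups_to_reduce.append("model_parallel")
    elif pp_world_size > 1:
        # MoE grads with MoE-TP disabled: reduce across pipeline parallel instead.
        groups_to_reduce.append("pipeline_parallel")
    if is_moe_grad:
        # MoE grads are additionally reduced across the expert parallel group.
        groups_to_reduce.append("expert_parallel")
    return groups_to_reduce

# MoE grads, MoE-TP disabled, 4 pipeline stages -> pipeline + expert parallel.
print(norm_reduce_groups(is_moe_grad=True, moe_tp_world_size=1, pp_world_size=4))
# The configuration fixed by this commit: no pipeline, TP enabled, MoE-TP disabled
# -> only the expert parallel group (previously the TP group was reduced over as well).
print(norm_reduce_groups(is_moe_grad=True, moe_tp_world_size=1, pp_world_size=1))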
23 changes: 20 additions & 3 deletions deepspeed/runtime/utils.py
@@ -25,7 +25,8 @@
 from torch import inf
 
 from deepspeed.utils import groups, logger
-from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
+                                 bwc_pipeline_parallel_group)
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from numpy import prod
 from deepspeed.accelerator import get_accelerator
@@ -856,8 +857,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
     if norm_type == inf:
         total_norm = max(t.data.abs().max() for t in input_tensors)
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
+        # Max across model parallel
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # For MoE grads, max over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, max over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: max across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
         total_norm = total_norm_cuda[0].item()
@@ -880,8 +889,16 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
             total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
 
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
+        # Sum across model parallel
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # For MoE grads, sum over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, sum over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: sum across expert parallel group
         if moe_ep_group is not None:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group)
 
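As a side note on why the two branches above use different reduction ops: for the L2 norm each rank contributes the sum of its squared gradient elements and the partial sums are added (ReduceOp.SUM) before the final root is taken, while for the infinity norm each rank contributes its local maximum and the maxima are combined with ReduceOp.MAX. A minimal numeric sketch with made-up values, not taken from DeepSpeed:

# Minimal numeric sketch (not from DeepSpeed): why the L2 path all-reduces a SUM
# of norm**2 while the inf path all-reduces a MAX. All values are made up.
local_grads = {
    "rank0": [3.0, 4.0],   # hypothetical gradient elements owned by rank 0
    "rank1": [1.0, 2.0],   # hypothetical gradient elements owned by rank 1
}

# L2: each rank contributes sum(g**2); the all_reduce(SUM) adds them, then take the root.
local_sq = {rank: sum(g * g for g in grads) for rank, grads in local_grads.items()}
global_l2 = sum(local_sq.values()) ** 0.5          # == (9 + 16 + 1 + 4) ** 0.5

# inf: each rank contributes max(|g|); the all_reduce(MAX) keeps the largest.
global_inf = max(max(abs(g) for g in grads) for grads in local_grads.values())

print(global_l2, global_inf)   # 5.477..., 4.0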
26 changes: 26 additions & 0 deletions deepspeed/utils/bwc.py
@@ -76,3 +76,29 @@ def bwc_tensor_model_parallel_group(mpu=None):
     else:
         # Deprecated Megatron and DeepSpeed convention
         return mpu.get_model_parallel_group()
+
+
+def bwc_pipeline_parallel_world_size(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel world size."""
+    world_size = 1
+    if mpu is not None:
+        if hasattr(mpu, 'get_pipeline_model_parallel_world_size'):
+            # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+            world_size = mpu.get_pipeline_model_parallel_world_size()
+        elif hasattr(mpu, 'get_pipe_parallel_world_size'):
+            # DeepSpeed Topology
+            world_size = mpu.get_pipe_parallel_world_size()
+    return world_size
+
+
+def bwc_pipeline_parallel_group(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel group."""
+    if mpu is None:
+        return None
+    if hasattr(mpu, 'get_pipeline_model_parallel_group'):
+        # Megatron
+        return mpu.get_pipeline_model_parallel_group()
+    elif hasattr(mpu, 'get_pipe_parallel_group'):
+        # DeepSpeed Topology
+        return mpu.get_pipe_parallel_group()
+    assert False, 'mpu does not support pipeline parallel group'
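A quick usage sketch of the two new helpers, using a stand-in mpu that follows the DeepSpeed Topology naming probed for above. The FakePipelineMPU stub and its return values are invented for this example; a real mpu hands back torch.distributed process groups, and the import assumes a DeepSpeed build that contains this commit.

# Hedged usage sketch (not part of the commit). The FakePipelineMPU stub and
# its return values are invented; requires DeepSpeed at or after this change.
from deepspeed.utils.bwc import bwc_pipeline_parallel_group, bwc_pipeline_parallel_world_size

class FakePipelineMPU:
    """Stand-in mpu exposing the DeepSpeed Topology style accessors."""

    def get_pipe_parallel_world_size(self):
        return 4                          # pretend there are 4 pipeline stages

    def get_pipe_parallel_group(self):
        return "pipe_parallel_group"      # a real mpu returns a ProcessGroup handle

mpu = FakePipelineMPU()
print(bwc_pipeline_parallel_world_size(mpu))   # 4, via get_pipe_parallel_world_size
print(bwc_pipeline_parallel_group(mpu))        # "pipe_parallel_group"
print(bwc_pipeline_parallel_world_size(None))  # 1 when no mpu is supplied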
