Skip to content

Commit

Permalink
Support MoE for pipeline models (microsoft#5338)
Browse files Browse the repository at this point in the history
This PR enhances DeepSpeed to support MoE for pipeline models (e.g.
GPTModelPipe from Megatron-DeepSpeed).
Main changes:

- Enhance expert groups creation for pipeline (enhance both flavors:
DP/PP/EP and DP/TP/PP/EP)
- Fix MoE save/load checkpoint for PipelineModule based models.
- Display MoE loss for PipelineModule based models.
- Support gradients reduce for BF16_Optimizer for
PipelineModule.<br>Note that same commit also fixes gradients reduction
error when using Megatron-DeepSpeed GPTModelPipe with BF16_Optimizer
also for a dense (no MOE) model.
- When using no-drop tokens, all-reduce the capacity (op=max) using
expert parallel group instead of world group

---------

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
  • Loading branch information
2 people authored and rraminen committed May 9, 2024
1 parent c29a6f3 commit c8065c3
Show file tree
Hide file tree
Showing 14 changed files with 326 additions and 141 deletions.
2 changes: 1 addition & 1 deletion deepspeed/moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self,

experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts,
min_capacity, noisy_gate_policy, drop_tokens, use_rts, None,
top2_2nd_expert_sampling),
experts,
self.expert_group_name,
Expand Down
16 changes: 9 additions & 7 deletions deepspeed/moe/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import torch
import deepspeed
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
bwc_tensor_model_parallel_group)


def _gather_tokens(input_, dim=0):
Expand All @@ -31,11 +33,11 @@ def _gather_tokens(input_, dim=0):

input_ = input_.contiguous()
# Size and dimension.
rank = mpu.get_tensor_model_parallel_rank()
rank = bwc_tensor_model_parallel_rank(mpu)

tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
tensor_list[rank] = input_
deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
Expand All @@ -47,8 +49,8 @@ def _drop_tokens(input_, dim=0):
"""Divide a tensor among the tensor parallel ranks"""
mpu = deepspeed.utils.groups.mpu

total_chunks = mpu.get_tensor_model_parallel_world_size()
this_chunk = mpu.get_tensor_model_parallel_rank()
total_chunks = bwc_tensor_model_parallel_world_size(mpu)
this_chunk = bwc_tensor_model_parallel_rank(mpu)
assert input_.shape[
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
chunk_size = input_.shape[dim] // total_chunks
Expand Down Expand Up @@ -92,15 +94,15 @@ def backward(ctx, input_):

def gather_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu
if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
# no tensor parallelism for non-experts
return input_
return _GatherTokens.apply(input_, dim)


def drop_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu
if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
# no tensor parallelism for non-experts
return input_
return _DropTokens.apply(input_, dim)
30 changes: 21 additions & 9 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.utils import logger
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -184,6 +185,7 @@ def top1gating(logits: Tensor,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
Expand All @@ -209,12 +211,13 @@ def top1gating(logits: Tensor,
# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
# Communicate across all processes to pick the maximum capacity.
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
# Communicate across expert processes to pick the maximum capacity.
if ep_group is not None:
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
# Make sure the capacity value does not exceed the number of tokens.
capacity = min(new_capacity, torch.tensor(mask1.size(0)))
Expand Down Expand Up @@ -286,6 +289,7 @@ def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int,
drop_tokens: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
Expand Down Expand Up @@ -328,11 +332,12 @@ def top2gating(logits: Tensor,
else:
# Do not drop tokens - set capacity according to current expert assignments
new_capacity = torch.max(exp_counts)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if ep_group is not None:
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

Expand Down Expand Up @@ -376,7 +381,7 @@ class TopKGate(Module):
Args:
model_dim (int):
size of model embedding dimension
num_experts (ints):
num_experts (int):
number of experts in model
"""

Expand All @@ -392,13 +397,15 @@ def __init__(self,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
Expand All @@ -411,6 +418,10 @@ def __init__(self,
self.use_rts = use_rts
self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

def _set_ep_group(self, ep_group):
assert self.ep_group is None, f'Attempting to override an existing ep_group'
self.ep_group = ep_group

def forward(self,
input: torch.Tensor,
used_token: torch.Tensor = None,
Expand All @@ -428,11 +439,11 @@ def forward(self,
if self.k == 1:
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)
self.drop_tokens, self.use_rts, self.ep_group, use_tutel)

else:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)

if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
Expand Down Expand Up @@ -492,6 +503,7 @@ def __init__(self,

def _set_ep_group(self, ep_group):
self.ep_group = ep_group
self.gate._set_ep_group(ep_group)

def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

Expand Down
2 changes: 1 addition & 1 deletion deepspeed/ops/transformer/inference/moe_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(self,
self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k,
self.config.capacity_factor, self.config.eval_capacity_factor,
self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens,
self.config.use_rts)
self.config.use_rts, self.ep_group)

self.ep_group = ep_group
self.mp_group = mp_group
Expand Down
3 changes: 2 additions & 1 deletion deepspeed/runtime/activation_checkpointing/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator

# DeepSpeed Checkpointing Enabled or Disabled
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
see_memory_usage, graph_process, get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand Down
7 changes: 5 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3228,9 +3228,12 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa

# Load flow uses below saved file for model parameters, RNG and more
if groups._get_data_parallel_rank() == 0:
# get non-moe parameters
# Get non-moe parameters
# Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict.
# DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None.
# We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine)
model_state_dict = self._get_non_moe_state_dict(
self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters))

# TODO: update num experts info,.. in checkpoint
state = {
Expand Down
Loading

0 comments on commit c8065c3

Please sign in to comment.