Support MoE for pipeline models #5338

Merged · 11 commits · Apr 8, 2024
2 changes: 1 addition & 1 deletion deepspeed/moe/layer.py
@@ -71,7 +71,7 @@ def __init__(self,

         experts = Experts(expert, self.num_local_experts, self.expert_group_name)
         self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
-                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts,
+                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts, None,
                                               top2_2nd_expert_sampling),
                                      experts,
                                      self.expert_group_name,
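The extra `None` threaded into the `TopKGate` call above fills the gate's new `ep_group` parameter (introduced in `sharded_moe.py` below); the expert-parallel group is not known at construction time and is wired in later via `_set_ep_group`. A positional sketch of how the arguments line up (signature taken from the diff below, annotations mine):

```python
# How the positional arguments above map onto TopKGate's parameters:
# TopKGate(model_dim, num_experts, k, capacity_factor, eval_capacity_factor,
#          min_capacity, noisy_gate_policy, drop_tokens, use_rts,
#          ep_group,                 # new parameter: None here, set later via _set_ep_group
#          top2_2nd_expert_sampling)
```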
16 changes: 9 additions & 7 deletions deepspeed/moe/mappings.py
@@ -23,6 +23,8 @@

 import torch
 import deepspeed
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
+                                 bwc_tensor_model_parallel_group)


 def _gather_tokens(input_, dim=0):
@@ -31,11 +33,11 @@ def _gather_tokens(input_, dim=0):

     input_ = input_.contiguous()
     # Size and dimension.
-    rank = mpu.get_tensor_model_parallel_rank()
+    rank = bwc_tensor_model_parallel_rank(mpu)

-    tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
+    tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
     tensor_list[rank] = input_
-    deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
+    deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))

     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -47,8 +49,8 @@ def _drop_tokens(input_, dim=0):
     """Divide a tensor among the tensor parallel ranks"""
     mpu = deepspeed.utils.groups.mpu

-    total_chunks = mpu.get_tensor_model_parallel_world_size()
-    this_chunk = mpu.get_tensor_model_parallel_rank()
+    total_chunks = bwc_tensor_model_parallel_world_size(mpu)
+    this_chunk = bwc_tensor_model_parallel_rank(mpu)
     assert input_.shape[
         dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
     chunk_size = input_.shape[dim] // total_chunks
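`_drop_tokens` and `_gather_tokens` are inverses along the chosen dimension: dropping keeps this rank's 1/tp slice, gathering reassembles the full tensor. A toy single-process illustration of the slicing arithmetic (no process groups involved):

```python
import torch

x = torch.arange(8)        # full activation along dim 0
tp, rank = 4, 1            # tensor-parallel world size and this rank's index
chunk = x.shape[0] // tp   # 2; the assert above guarantees divisibility
mine = x[rank * chunk:(rank + 1) * chunk]  # what _drop_tokens keeps: tensor([2, 3])
full = torch.cat([x[r * chunk:(r + 1) * chunk] for r in range(tp)])  # what gathering rebuilds
assert torch.equal(full, x)
```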
@@ -92,15 +94,15 @@ def backward(ctx, input_):

 def gather_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu
-    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
+    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
         # no tensor parallelism for non-experts
         return input_
     return _GatherTokens.apply(input_, dim)


 def drop_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu
-    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
+    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
         # no tensor parallelism for non-experts
         return input_
     return _DropTokens.apply(input_, dim)
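The `bwc_*` helpers replace direct `mpu.get_tensor_model_parallel_*` calls so the same code works with both newer and legacy Megatron-style `mpu` objects. A minimal sketch of what such a backward-compatibility shim looks like, assuming the legacy interface exposes `get_model_parallel_rank()` (the actual helpers live in `deepspeed/utils/bwc.py`):

```python
def bwc_tensor_model_parallel_rank(mpu=None):
    """Return the tensor-model-parallel rank, tolerating old and new mpu APIs."""
    if mpu is None:
        return 0  # no model parallelism configured
    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        return mpu.get_tensor_model_parallel_rank()  # newer Megatron-style API
    return mpu.get_model_parallel_rank()  # legacy API predating the tensor/pipeline split
```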
30 changes: 21 additions & 9 deletions deepspeed/moe/sharded_moe.py
@@ -17,7 +17,8 @@

 from deepspeed.utils.timer import SynchronizedWallClockTimer
 from deepspeed.utils import logger
-from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
+from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -184,6 +185,7 @@ def top1gating(logits: Tensor,
                noisy_gate_policy: Optional[str] = None,
                drop_tokens: bool = True,
                use_rts: bool = True,
+               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top1Gating on logits."""
     if noisy_gate_policy == 'RSample':
@@ -209,12 +211,13 @@ def top1gating(logits: Tensor,
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
-        # Communicate across all processes to pick the maximum capacity.
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        # Communicate across expert processes to pick the maximum capacity.
+        if ep_group is not None:
+            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
             # This is since we are going to activate drop_tokens() to drop duplicate tokens.
-            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
             new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         # Make sure the capacity value does not exceed the number of tokens.
         capacity = min(new_capacity, torch.tensor(mask1.size(0)))
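When tokens are not dropped, every rank in the expert-parallel group must agree on a single capacity, so the per-rank maximum expert load is now MAX-all-reduced over `ep_group` rather than the whole world group, which is what makes this safe for pipeline models where other stages never reach this collective. The padding step then rounds the capacity up to a multiple of the tensor-parallel degree so duplicate tokens can later be dropped evenly. A standalone sketch of just that arithmetic, with made-up numbers:

```python
import math

# Suppose the MAX all-reduce across the expert-parallel group settles on
# a capacity of 13 tokens per expert.
new_capacity = 13
tp = 4  # tensor-parallel world size of the non-expert part

# Round capacity up to a multiple of tp so drop_tokens()/gather_tokens()
# can split the capacity dimension evenly across tensor-parallel ranks.
padded = math.ceil(new_capacity / tp) * tp
print(padded)  # 16
```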
@@ -286,6 +289,7 @@ def top2gating(logits: Tensor,
                capacity_factor: float,
                min_capacity: int,
                drop_tokens: bool = True,
+               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
                top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
@@ -328,11 +332,12 @@
     else:
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if ep_group is not None:
+            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
             # This is since we are going to activate drop_tokens() to drop duplicate tokens.
-            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
             new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

@@ -376,7 +381,7 @@ class TopKGate(Module):
     Args:
         model_dim (int):
             size of model embedding dimension
-        num_experts (ints):
+        num_experts (int):
             number of experts in model
     """

@@ -392,13 +397,15 @@ def __init__(self,
                  noisy_gate_policy: Optional[str] = None,
                  drop_tokens: bool = True,
                  use_rts: bool = True,
+                 ep_group: Union[torch.distributed.ProcessGroup, None] = None,
                  top2_2nd_expert_sampling: bool = True) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
         if k != 1 and k != 2:
             raise ValueError('Only top-1 and top-2 gatings are supported.')
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
+        self.ep_group = ep_group
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
@@ -411,6 +418,10 @@ def __init__(self,
         self.use_rts = use_rts
         self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

+    def _set_ep_group(self, ep_group):
+        assert self.ep_group is None, f'Attempting to override an existing ep_group'
+        self.ep_group = ep_group
+
     def forward(self,
                 input: torch.Tensor,
                 used_token: torch.Tensor = None,
@@ -428,11 +439,11 @@ def forward(self,
         if self.k == 1:
             gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                      self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
-                                     self.drop_tokens, self.use_rts, use_tutel)
+                                     self.drop_tokens, self.use_rts, self.ep_group, use_tutel)

         else:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)

         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).stop()
@@ -492,6 +503,7 @@ def __init__(self,

     def _set_ep_group(self, ep_group):
         self.ep_group = ep_group
+        self.gate._set_ep_group(ep_group)

     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
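Taken together, `MOELayer._set_ep_group` now propagates the expert-parallel group into the gate, so the gate's capacity all-reduce runs over the right ranks. A minimal sketch of the wiring order, with toy classes standing in for the real ones, assuming process groups are created only after model construction:

```python
class ToyGate:
    def __init__(self):
        self.ep_group = None  # unknown at construction time

    def _set_ep_group(self, ep_group):
        assert self.ep_group is None  # wire in exactly once
        self.ep_group = ep_group

class ToyMoELayer:
    def __init__(self):
        self.ep_group = None
        self.gate = ToyGate()

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group
        self.gate._set_ep_group(ep_group)  # keep layer and gate in sync

layer = ToyMoELayer()                  # built before process groups exist
layer._set_ep_group("ep_group_handle")  # stand-in for a torch.distributed group
assert layer.gate.ep_group is not None
```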
2 changes: 1 addition & 1 deletion deepspeed/ops/transformer/inference/moe_inference.py
@@ -226,7 +226,7 @@ def __init__(self,
         self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k,
                                  self.config.capacity_factor, self.config.eval_capacity_factor,
                                  self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens,
-                                 self.config.use_rts)
+                                 self.config.use_rts, self.ep_group)

         self.ep_group = ep_group
         self.mp_group = mp_group
Expand Down
@@ -25,8 +25,9 @@

 from deepspeed.runtime.config import DeepSpeedConfig
 from deepspeed.utils import logger
-from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
+from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
 from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
 from deepspeed.accelerator import get_accelerator

 # DeepSpeed Checkpointing Enabled or Disabled
6 changes: 3 additions & 3 deletions deepspeed/runtime/bf16_optimizer.py
@@ -13,11 +13,11 @@

 from packaging import version as pkg_version
 from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
-                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
-                                     is_model_parallel_parameter, see_memory_usage, graph_process,
-                                     get_norm_with_moe_layers)
+                                     align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
+                                     see_memory_usage, graph_process, get_norm_with_moe_layers)
 from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
 from deepspeed.moe.utils import is_moe_param, is_moe_param_group
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
7 changes: 5 additions & 2 deletions deepspeed/runtime/engine.py
@@ -3228,9 +3228,12 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters

         # Load flow uses below saved file for model parameters, RNG and more
         if groups._get_data_parallel_rank() == 0:
-            # get non-moe parameters
+            # Get non-moe parameters.
+            # DeepSpeedEngine and PipelineEngine implement module_state_dict differently:
+            # DeepSpeedEngine returns the state dict, whereas PipelineEngine saves it and returns None.
+            # Since we need the state dict itself, call DeepSpeedEngine (the base class of PipelineEngine) directly.
             model_state_dict = self._get_non_moe_state_dict(
-                self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
+                DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters))

             # TODO: update num experts info,.. in checkpoint
             state = {
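The fix relies on Python's explicit base-class call: `DeepSpeedEngine.module_state_dict(self, ...)` bypasses PipelineEngine's override, which returns None. A minimal self-contained sketch of the pattern, with toy classes standing in for the engines:

```python
class Base:
    def state(self):
        return {"w": 1}  # returns the state dict

class Derived(Base):
    def state(self):
        super().state()  # hypothetically persists the state elsewhere...
        return None      # ...and returns nothing

obj = Derived()
assert obj.state() is None          # the override loses the return value
assert Base.state(obj) == {"w": 1}  # explicit base-class call recovers it
```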