From e1830b4ef76f8be22ebcecc60bfde42614be2376 Mon Sep 17 00:00:00 2001
From: Moshe Island
Date: Mon, 8 Apr 2024 18:35:53 +0300
Subject: [PATCH] Support MoE for pipeline models (#5338)

This PR enhances DeepSpeed to support MoE for pipeline models (e.g. GPTModelPipe
from Megatron-DeepSpeed).

Main changes:
- Enhance expert group creation for pipeline models (both flavors: DP/PP/EP and
  DP/TP/PP/EP).
- Fix MoE save/load checkpoint for PipelineModule-based models.
- Display MoE loss for PipelineModule-based models.
- Support gradient reduction with BF16_Optimizer for PipelineModule.
Note that same commit also fixes gradients reduction error when using Megatron-DeepSpeed GPTModelPipe with BF16_Optimizer also for a dense (no MOE) model. - When using no-drop tokens, all-reduce the capacity (op=max) using expert parallel group instead of world group --------- Signed-off-by: Moshe Island Co-authored-by: Moshe Island --- deepspeed/moe/layer.py | 2 +- deepspeed/moe/mappings.py | 16 +-- deepspeed/moe/sharded_moe.py | 30 +++-- .../transformer/inference/moe_inference.py | 2 +- .../activation_checkpointing/checkpointing.py | 3 +- deepspeed/runtime/bf16_optimizer.py | 6 +- deepspeed/runtime/engine.py | 7 +- deepspeed/runtime/pipe/engine.py | 114 ++++++++++++++---- deepspeed/runtime/pipe/module.py | 7 ++ deepspeed/runtime/utils.py | 60 +++------ deepspeed/runtime/zero/stage_1_and_2.py | 6 +- deepspeed/utils/bwc.py | 104 ++++++++++++++++ deepspeed/utils/groups.py | 108 +++++++++-------- tests/unit/utils/test_groups.py | 2 +- 14 files changed, 326 insertions(+), 141 deletions(-) create mode 100644 deepspeed/utils/bwc.py diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py index dfa9fcf4f464..6777788ab885 100644 --- a/deepspeed/moe/layer.py +++ b/deepspeed/moe/layer.py @@ -71,7 +71,7 @@ def __init__(self, experts = Experts(expert, self.num_local_experts, self.expert_group_name) self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor, - min_capacity, noisy_gate_policy, drop_tokens, use_rts, + min_capacity, noisy_gate_policy, drop_tokens, use_rts, None, top2_2nd_expert_sampling), experts, self.expert_group_name, diff --git a/deepspeed/moe/mappings.py b/deepspeed/moe/mappings.py index 6c501ea6503a..b8a06274343a 100644 --- a/deepspeed/moe/mappings.py +++ b/deepspeed/moe/mappings.py @@ -23,6 +23,8 @@ import torch import deepspeed +from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank, + bwc_tensor_model_parallel_group) def _gather_tokens(input_, dim=0): @@ -31,11 +33,11 @@ def _gather_tokens(input_, dim=0): input_ = input_.contiguous() # Size and dimension. - rank = mpu.get_tensor_model_parallel_rank() + rank = bwc_tensor_model_parallel_rank(mpu) - tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())] + tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))] tensor_list[rank] = input_ - deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group()) + deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu)) # Note: torch.cat already creates a contiguous tensor. 
output = torch.cat(tensor_list, dim=dim).contiguous() @@ -47,8 +49,8 @@ def _drop_tokens(input_, dim=0): """Divide a tensor among the tensor parallel ranks""" mpu = deepspeed.utils.groups.mpu - total_chunks = mpu.get_tensor_model_parallel_world_size() - this_chunk = mpu.get_tensor_model_parallel_rank() + total_chunks = bwc_tensor_model_parallel_world_size(mpu) + this_chunk = bwc_tensor_model_parallel_rank(mpu) assert input_.shape[ dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})" chunk_size = input_.shape[dim] // total_chunks @@ -92,7 +94,7 @@ def backward(ctx, input_): def gather_tokens(input_, dim=0): mpu = deepspeed.utils.groups.mpu - if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1: + if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1: # no tensor parallelism for non-experts return input_ return _GatherTokens.apply(input_, dim) @@ -100,7 +102,7 @@ def gather_tokens(input_, dim=0): def drop_tokens(input_, dim=0): mpu = deepspeed.utils.groups.mpu - if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1: + if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1: # no tensor parallelism for non-experts return input_ return _DropTokens.apply(input_, dim) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index e685a0f574f3..bd2782279c01 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -17,7 +17,8 @@ from deepspeed.utils.timer import SynchronizedWallClockTimer from deepspeed.utils import logger -from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple +from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size +from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union import torch from torch import Tensor @@ -184,6 +185,7 @@ def top1gating(logits: Tensor, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top1Gating on logits.""" if noisy_gate_policy == 'RSample': @@ -209,12 +211,13 @@ def top1gating(logits: Tensor, # if we don't want to drop any tokens if not drop_tokens: new_capacity = torch.max(exp_counts).to(logits.device) - # Communicate across all processes to pick the maximum capacity. - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) + # Communicate across expert processes to pick the maximum capacity. + if ep_group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) if groups._get_expert_model_parallel_world_size() == 1: # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. # This is since we are going to activate drop_tokens() to drop duplicate tokens. - tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size() + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) # Make sure the capacity value does not exceed the number of tokens. 
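A minimal standalone sketch of the capacity synchronization shown above for the no-drop-tokens path: the per-rank maximum expert count is all-reduced (max) over the expert-parallel group rather than the world group, then padded to a multiple of the non-expert tensor-parallel size so duplicate tokens can later be dropped evenly. The helper name and free-standing form are illustrative, not part of the patch.

import torch
import torch.distributed as dist

def sync_no_drop_capacity(exp_counts: torch.Tensor, ep_group=None, tp: int = 1) -> torch.Tensor:
    # Largest per-expert token count seen on this rank.
    new_capacity = torch.max(exp_counts)
    # Agree on the maximum across expert-parallel ranks only (previously the world group).
    if ep_group is not None and dist.is_initialized():
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
    # Pad to a multiple of the tensor-parallel size so drop_tokens() can
    # split the capacity evenly across TP ranks.
    new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
    return new_capacity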
capacity = min(new_capacity, torch.tensor(mask1.size(0))) @@ -286,6 +289,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int, drop_tokens: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top2Gating on logits.""" # everything is in fp32 in this function @@ -328,11 +332,12 @@ def top2gating(logits: Tensor, else: # Do not drop tokens - set capacity according to current expert assignments new_capacity = torch.max(exp_counts) - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) + if ep_group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) if groups._get_expert_model_parallel_world_size() == 1: # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'. # This is since we are going to activate drop_tokens() to drop duplicate tokens. - tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size() + tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) capacity = new_capacity @@ -376,7 +381,7 @@ class TopKGate(Module): Args: model_dim (int): size of model embedding dimension - num_experts (ints): + num_experts (int): number of experts in model """ @@ -392,6 +397,7 @@ def __init__(self, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, + ep_group: Union[torch.distributed.ProcessGroup, None] = None, top2_2nd_expert_sampling: bool = True) -> None: super().__init__() @@ -399,6 +405,7 @@ def __init__(self, if k != 1 and k != 2: raise ValueError('Only top-1 and top-2 gatings are supported.') self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) + self.ep_group = ep_group self.k = k self.capacity_factor = capacity_factor self.eval_capacity_factor = eval_capacity_factor @@ -411,6 +418,10 @@ def __init__(self, self.use_rts = use_rts self.top2_2nd_expert_sampling = top2_2nd_expert_sampling + def _set_ep_group(self, ep_group): + assert self.ep_group is None, f'Attempting to override an existing ep_group' + self.ep_group = ep_group + def forward(self, input: torch.Tensor, used_token: torch.Tensor = None, @@ -428,11 +439,11 @@ def forward(self, if self.k == 1: gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, - self.drop_tokens, self.use_rts, use_tutel) + self.drop_tokens, self.use_rts, self.ep_group, use_tutel) else: gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling) + self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling) if self.wall_clock_breakdown: self.timers(TOPK_GATE_TIMER).stop() @@ -492,6 +503,7 @@ def __init__(self, def _set_ep_group(self, ep_group): self.ep_group = ep_group + self.gate._set_ep_group(ep_group) def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py index 90bfcae81bf2..8766b65e866d 100644 --- a/deepspeed/ops/transformer/inference/moe_inference.py +++ b/deepspeed/ops/transformer/inference/moe_inference.py @@ -226,7 +226,7 @@ def __init__(self, self.moe_gate = 
TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k, self.config.capacity_factor, self.config.eval_capacity_factor, self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens, - self.config.use_rts) + self.config.use_rts, self.ep_group) self.ep_group = ep_group self.mp_group = mp_group diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 02e0b197e927..2a21cf7ca17a 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -25,8 +25,9 @@ from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.utils import logger -from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank +from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.accelerator import get_accelerator # DeepSpeed Checkpointing Enabled or Disabled diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index d076035604e3..f970e582b354 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -13,11 +13,11 @@ from packaging import version as pkg_version from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, - align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage, graph_process, - get_norm_with_moe_layers) + align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter, + see_memory_usage, graph_process, get_norm_with_moe_layers) from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups from deepspeed.moe.utils import is_moe_param, is_moe_param_group +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 31bb58b64e04..992d7877c179 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3228,9 +3228,12 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # Load flow uses below saved file for model parameters, RNG and more if groups._get_data_parallel_rank() == 0: - # get non-moe parameters + # Get non-moe parameters + # Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict. + # DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None. + # We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine) model_state_dict = self._get_non_moe_state_dict( - self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)) + DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters)) # TODO: update num experts info,.. 
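A toy illustration (all names are placeholders, not DeepSpeed APIs) of the explicit base-class call used above: the subclass override saves the state dict and returns None, so the caller invokes the base implementation directly to obtain the dict.

class Base:
    def module_state_dict(self):
        return {"weight": 1}

class Pipe(Base):
    def module_state_dict(self):
        # Saves the state dict elsewhere and returns None.
        return None

engine = Pipe()
assert engine.module_state_dict() is None
assert Base.module_state_dict(engine) == {"weight": 1}  # bypasses the override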
in checkpoint state = { diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index ef1c98a95c7b..1dda7f1aad32 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -4,6 +4,7 @@ # DeepSpeed Team from types import MethodType +from collections import OrderedDict import torch from deepspeed import comm as dist @@ -194,9 +195,15 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): #stores the loss for the entire batch self.total_loss = None + self.total_additional_losses = None self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device) self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device) + # stores aggregated-DP train final loss and aggregated-DP additional losses, if any + # additional losses are stored as dict: {loss-name: agg-loss} + self.agg_train_loss = None + self.agg_additional_losses = None + if self._config.pipeline['activation_checkpoint_interval'] > 0: self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval'] # set use_reentrant default to True. @@ -284,10 +291,7 @@ def _exec_reduce_grads(self): self._force_grad_boundary = False def _bf16_reduce_grads(self): - # Make our own list of gradients from the optimizer's FP32 grads - grads = [] - self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(), - elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) + self.buffered_allreduce_fallback(grads=None, elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) def _reserve_pipe_buffers(self, num_buffers): """Ensure that each pipeline buffer has at least ``num_buffers`` slots. @@ -363,6 +367,7 @@ def train_batch(self, data_iter=None): self.module.train() self.total_loss = None + self.total_additional_losses = None self._compute_loss = True # Do the work @@ -371,7 +376,9 @@ def train_batch(self, data_iter=None): stages=self.num_stages, stage_id=self.stage_id) self._exec_schedule(sched) - self.agg_train_loss = self._aggregate_total_loss() + + with torch.no_grad(): + self.agg_train_loss = self._aggregate_total_loss() self.timers(TRAIN_BATCH_TIMER).stop() @@ -380,10 +387,12 @@ def train_batch(self, data_iter=None): elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0 iter_time = elapsed / self.steps_per_print() tput = self.train_batch_size() / iter_time - print(f'steps: {self.global_steps} ' - f'loss: {self.agg_train_loss:0.4f} ' - f'iter time (s): {iter_time:0.3f} ' - f'samples/sec: {tput:0.3f}') + log_str = f'steps: {self.global_steps} loss: {self.agg_train_loss:0.4f} ' + if self.agg_additional_losses is not None: + for loss_name, loss_value in self.agg_additional_losses.items(): + log_str += f'{loss_name}: {loss_value.item():0.4f} ' + log_str += f'iter time (s): {iter_time:0.3f} samples/sec: {tput:0.3f}' + print(log_str) else: self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) @@ -565,29 +574,66 @@ def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32): def _aggregate_total_loss(self): # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group if self.is_last_stage(): + # Scale loss and additional losses, if any loss = self._scale_loss_by_gas(self.total_loss) - self.dp_group_loss = loss.clone().detach() + self.agg_additional_losses = self.total_additional_losses + if self.agg_additional_losses is not None: + self.agg_additional_losses = OrderedDict({ + loss_name: self._scale_loss_by_gas(_loss.clone().detach()) + for loss_name, _loss in self.agg_additional_losses.items() + }) 
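A minimal sketch, with assumed placeholder arguments, of the single-all-reduce scheme used in the hunk that follows: the scaled train loss and any additional named losses are packed into one flat tensor, reduced once over the data-parallel group, and unpacked back by name.

from collections import OrderedDict
import torch
import torch.distributed as dist

def reduce_losses_once(train_loss, additional_losses, dp_group, dp_world_size):
    assert '__train_loss__' not in additional_losses
    tensors = OrderedDict({'__train_loss__': train_loss})
    tensors.update(additional_losses)
    # One collective instead of one per loss.
    flat = torch.cat([t.detach().reshape(-1).float() for t in tensors.values()])
    dist.all_reduce(flat, group=dp_group)
    flat /= dp_world_size
    reduced, offset = {}, 0
    for name, t in tensors.items():
        reduced[name] = flat[offset:offset + t.numel()].reshape(t.shape).clone()
        offset += t.numel()
    agg_loss = reduced['__train_loss__']
    return agg_loss, OrderedDict((name, reduced[name]) for name in additional_losses)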
- ## Average loss across all data-parallel groups + self.dp_group_loss = loss.clone().detach() agg_loss = self.dp_group_loss.clone().detach() #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True) + + # Average loss across all data-parallel groups if self.is_data_parallel: - dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) - agg_loss /= self.dp_world_size + if self.agg_additional_losses is None: + dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group()) + agg_loss /= self.dp_world_size + else: + # use a single reduce op for agg_loss and additional losses, if any + assert '__train_loss__' not in self.agg_additional_losses.keys() + tensors = OrderedDict({'__train_loss__': agg_loss}) + tensors.update(self.agg_additional_losses.items()) + flat_tensor = torch.cat([t.clone().reshape(-1).detach() for t in tensors.values()]) + dist.all_reduce(flat_tensor, group=self.mpu.get_data_parallel_group()) + flat_tensor /= self.dp_world_size + offset = 0 + reduced_tensor = {} + for name, t in tensors.items(): + n_elem = t.numel() + reduced_tensor[name] = flat_tensor[offset:offset + n_elem].clone().detach().reshape(t.shape) + offset += n_elem + agg_loss = reduced_tensor['__train_loss__'] + self.agg_additional_losses = OrderedDict( + {name: reduced_tensor[name] + for name in self.agg_additional_losses.keys()}) assert self.global_rank in self.grid.pp_group - losses = torch.stack([self.dp_group_loss, agg_loss]).float() + losses = [self.dp_group_loss, agg_loss] + if self.agg_additional_losses is not None: + losses += list(self.agg_additional_losses.values()) + losses = torch.stack(losses).float() if self.is_pipe_parallel: dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: # Get loss from last stage src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group - losses = torch.Tensor([0., 0.]).to(self.device) + # losses to reduce are: dp_group_loss, agg_loss, model additional losses + # therefore: 2 + n_additional_losses + additional_losses = self.module.get_additional_losses() + n_additional_losses = 0 if additional_losses is None else len(additional_losses) + losses = torch.Tensor([0.] 
* (2 + n_additional_losses)).to(self.device) dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group()) self.dp_group_loss = losses[0].clone().detach() agg_loss = losses[1].clone().detach() - + if additional_losses is not None: + self.agg_additional_losses = OrderedDict( + {name: losses[2 + i].clone().detach() + for i, name in enumerate(additional_losses.keys())}) return agg_loss def set_dataloader(self, loader): @@ -715,19 +761,34 @@ def _exec_forward_pass(self, buffer_id): self.loss = outputs if self.eval_return_logits: self.outputs = outputs + if isinstance(self.loss, torch.Tensor): self.fwd_outputs.append(self.loss.detach()) - - if self.total_loss is None: - self.total_loss = torch.zeros_like(self.loss) - self.total_loss += self.loss.detach() else: self.fwd_outputs.append([l.detach() for l in self.loss]) - if self.total_loss is None: - self.total_loss = [torch.zeros_like(l) for l in self.loss] - for idx, l in enumerate(self.loss): - self.total_loss[idx] += l.detach() + def add_to_total_loss(_total_loss, _loss): + if isinstance(_loss, torch.Tensor): + if _total_loss is None: + _total_loss = torch.zeros_like(_loss) + _total_loss += _loss.detach() + else: + if _total_loss is None: + _total_loss = [torch.zeros_like(_l) for _l in _loss] + for _idx, _l in enumerate(_loss): + _total_loss[_idx] += _l.detach() + return _total_loss + + self.total_loss = add_to_total_loss(self.total_loss, self.loss) + + # aggregate additional losses across gradient accumulation steps + additional_losses = self.module.get_additional_losses() + if additional_losses is not None: + if self.total_additional_losses is None: + self.total_additional_losses = OrderedDict() + for name, loss in additional_losses.items(): + total = self.total_additional_losses[name] if name in self.total_additional_losses else None + self.total_additional_losses[name] = add_to_total_loss(total, loss) def _exec_backward_pass(self, buffer_id): assert self.optimizer is not None, "must provide optimizer during " \ @@ -1332,7 +1393,7 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f strict (bool, optional): Strict state loading. Defaults to True. """ assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism" - state_dict = checkpoint['module'] + state_dict = checkpoint if self.has_moe_layers else checkpoint['module'] if (state_dict is not None) and (not isinstance(state_dict, str)): super().load_module_state_dict(state_dict, strict) return @@ -1371,3 +1432,6 @@ def _exec_schedule(self, pipe_schedule): # Equivalent to: self._exec_forward_pass(buffer_id=0) self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self) self._exec_instr(**cmd.kwargs) + + def get_additional_losses(self): + return self.agg_additional_losses diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index c11379b0a0d7..8036faef72ee 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -634,3 +634,10 @@ def _is_checkpointable(self, funcs): return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs) params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] return any(len(list(p)) > 0 for p in params) + + def get_additional_losses(self): + """ Returns model specific additional losses for reporting + + Return a dictionary of {"loss name": loss_value} or None if no additional losses. 
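A hypothetical usage sketch (not part of the patch) of the get_additional_losses hook being added to PipelineModule here: a user-defined pipeline model could override it to surface an MoE auxiliary loss for the per-step log line. The class and attribute names below are assumptions.

from collections import OrderedDict
from deepspeed.pipe import PipelineModule

class GPTModelPipeWithMoE(PipelineModule):   # assumed user-defined model
    def get_additional_losses(self):
        # `self._moe_aux_loss` is an assumed attribute updated during forward.
        aux = getattr(self, '_moe_aux_loss', None)
        return None if aux is None else OrderedDict({'moe loss': aux})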
+ """ + return None diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index c1c2b6c61cfd..9d561f7271eb 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -25,6 +25,8 @@ from torch import inf from deepspeed.utils import groups, logger +from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size, + bwc_pipeline_parallel_group) from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod from deepspeed.accelerator import get_accelerator @@ -117,44 +119,6 @@ def is_model_parallel_parameter(p) -> bool: return False -def bwc_tensor_model_parallel_rank(mpu=None): - """Backwards-compatible way of querying the tensor model parallel rank from - an ``mpu`` object. - - *Tensor* model parallelism means that tensors are physically split across - processes. This contrasts with *pipeline* model parallelism, in which the - layers are partitioned but tensors left intact. - - The API for tensor model parallelism has changed across versions and this - helper provides a best-effort implementation across versions of ``mpu`` - objects. The preferred mechanism is - ``mpu.get_tensor_model_parallel_rank()``. - - This should "just work" with both Megatron-LM and DeepSpeed's pipeline - parallelism. - - Args: - mpu (model parallel unit, optional): The tensor model parallel rank. - If ``mpu=None``, returns 0. Defaults to ``None``. - - Returns: - int: the rank - """ - if mpu is None: - # No model parallelism in easy :) - return 0 - - if hasattr(mpu, 'get_tensor_model_parallel_rank'): - # New Megatron and DeepSpeed convention (post pipeline-parallelism release) - return mpu.get_tensor_model_parallel_rank() - elif hasattr(mpu, 'get_slice_parallel_rank'): - # Some DeepSpeed + pipeline parallelism versions - return mpu.get_slice_parallel_rank() - else: - # Deprecated Megatron and DeepSpeed convention - return mpu.get_model_parallel_rank() - - def copy_to_device(item, device, criterion_func): """ Return a copy of tensor on specified device. 
@@ -894,8 +858,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F all_norms.append(t.data.abs().max().float()) total_norm = torch.stack(all_norms).max() device_total_norm = total_norm.to(get_accelerator().current_device_name()) + # Max across model parallel if mpu is not None: - dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) + # For MoE grads, max over model parallel only if MoE-TP is enabled + if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) + # If MoE grads and MoE-TP disabled, max over pipeline parallel + elif bwc_pipeline_parallel_world_size(mpu) > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu)) + + # MoE grads: max across expert parallel group if moe_ep_group is not None: dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group) total_norm = device_total_norm.to(input_tensors[0].device) @@ -922,8 +894,16 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type): device_total_norm = compute_buffer[0].float().detach() + # Sum across model parallel if mpu is not None: - dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + # For MoE grads, sum over model parallel only if MoE-TP is enabled + if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + # If MoE grads and MoE-TP disabled, sum over pipeline parallel + elif bwc_pipeline_parallel_world_size(mpu) > 1: + dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu)) + + # MoE grads: sum across expert parallel group if moe_ep_group is not None: dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group) total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 2f98379aa14d..16b9c3c18919 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -11,13 +11,13 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler -from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, empty_cache, see_memory_usage, inf, - is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) - +from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, + align_dense_tensors, all_gather_dp_groups) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.utils import logger +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version diff --git a/deepspeed/utils/bwc.py b/deepspeed/utils/bwc.py new file mode 100644 index 000000000000..69fcc251a684 --- /dev/null +++ b/deepspeed/utils/bwc.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + + +def bwc_tensor_model_parallel_rank(mpu=None): + """Backwards-compatible way of querying the tensor model parallel rank from + an ``mpu`` object. 
+ + *Tensor* model parallelism means that tensors are physically split across + processes. This contrasts with *pipeline* model parallelism, in which the + layers are partitioned but tensors left intact. + + The API for tensor model parallelism has changed across versions and this + helper provides a best-effort implementation across versions of ``mpu`` + objects. The preferred mechanism is + ``mpu.get_tensor_model_parallel_rank()``. + + This should "just work" with both Megatron-LM and DeepSpeed's pipeline + parallelism. + + Args: + mpu (model parallel unit, optional): The tensor model parallel rank. + If ``mpu=None``, returns 0. Defaults to ``None``. + + Returns: + int: the rank + """ + if mpu is None: + # No model parallelism in easy :) + return 0 + + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_rank() + elif hasattr(mpu, 'get_slice_parallel_rank'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_rank() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_rank() + + +def bwc_tensor_model_parallel_world_size(mpu=None): + """Backwards-compatible way of querying the tensor model parallel world size. + Similar to bwc_tensor_model_parallel_rank. + """ + if mpu is None: + return 1 + + if hasattr(mpu, 'get_tensor_model_parallel_world_size'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_world_size() + elif hasattr(mpu, 'get_slice_parallel_world_size'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_world_size() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_world_size() + + +def bwc_tensor_model_parallel_group(mpu=None): + """Backwards-compatible way of querying the tensor model parallel group. + Similar to bwc_tensor_model_parallel_rank. 
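A small usage sketch of the new bwc helpers (the legacy mpu stub below is an assumption, not from the patch): each helper probes for the newest accessor first, falls back to older conventions, and returns a sensible default when no mpu is supplied.

from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size,
                                 bwc_pipeline_parallel_world_size)

class LegacyMPU:
    """Stand-in for an old mpu object that predates the tensor-parallel API."""
    def get_model_parallel_world_size(self):
        return 4

assert bwc_tensor_model_parallel_world_size(None) == 1          # no model parallelism
assert bwc_tensor_model_parallel_world_size(LegacyMPU()) == 4   # deprecated accessor path
assert bwc_pipeline_parallel_world_size(None) == 1              # no pipeline parallelism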
+ """ + if mpu is None: + return None + + if hasattr(mpu, 'get_tensor_model_parallel_group'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + return mpu.get_tensor_model_parallel_group() + elif hasattr(mpu, 'get_slice_parallel_group'): + # Some DeepSpeed + pipeline parallelism versions + return mpu.get_slice_parallel_group() + else: + # Deprecated Megatron and DeepSpeed convention + return mpu.get_model_parallel_group() + + +def bwc_pipeline_parallel_world_size(mpu=None): + """Backwards-compatible way of querying the pipeline parallel world size.""" + world_size = 1 + if mpu is not None: + if hasattr(mpu, 'get_pipeline_model_parallel_world_size'): + # New Megatron and DeepSpeed convention (post pipeline-parallelism release) + world_size = mpu.get_pipeline_model_parallel_world_size() + elif hasattr(mpu, 'get_pipe_parallel_world_size'): + # DeepSpeed Topology + world_size = mpu.get_pipe_parallel_world_size() + return world_size + + +def bwc_pipeline_parallel_group(mpu=None): + """Backwards-compatible way of querying the pipeline parallel group.""" + if mpu is None: + return None + if hasattr(mpu, 'get_pipeline_model_parallel_group'): + # Megatron + return mpu.get_pipeline_model_parallel_group() + elif hasattr(mpu, 'get_pipe_parallel_group'): + # DeepSpeed Topology + return mpu.get_pipe_parallel_group() + assert False, 'mpu does not support pipeline parallel group' diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 63dda7f5aaae..c49f4520e16e 100644 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -27,6 +27,7 @@ from deepspeed import comm as dist from deepspeed.utils import log_dist +from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size from deepspeed.utils.exceptions import DeprecatedException from deepspeed.accelerator import get_accelerator # Expert parallel group that the current rank belongs to. @@ -128,31 +129,32 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0]) world_size = dist.get_world_size() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) rank = dist.get_rank() - _ensure_divisibility(world_size, expert_parallel_size_) + pp_stride = world_size // pp_world_size + _ensure_divisibility(pp_stride, expert_parallel_size_) group_name = f"ep_size_{expert_parallel_size_}" # Build the expert data parallel groups. 
global _EXPERT_DATA_PARALLEL_GROUP - ep_stride = world_size // expert_parallel_size_ + ep_stride = pp_stride // expert_parallel_size_ # Only create group if it does not already exist if group_name not in _EXPERT_DATA_PARALLEL_GROUP: - for i in range(expert_parallel_size_): - if use_data_before_expert_parallel_: - ranks = range(i * ep_stride, (i + 1) * ep_stride) - else: - ranks = range(i, world_size, expert_parallel_size_) - group = dist.new_group(ranks) - log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if use_data_before_expert_parallel_: - if i == (rank // ep_stride): - _EXPERT_DATA_PARALLEL_GROUP[group_name] = group - else: - if i == (rank % expert_parallel_size_): + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(expert_parallel_size_): + if use_data_before_expert_parallel_: + ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride) + else: + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_) + group = dist.new_group(ranks) + log_dist( + f'Creating expert data parallel process group named {group_name} ' + f'with ranks: {list(ranks)}', [0]) + if rank in ranks: _EXPERT_DATA_PARALLEL_GROUP[group_name] = group # Build the expert parallel groups. @@ -161,24 +163,29 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe # Only create group if it does not already exist if group_name not in _EXPERT_PARALLEL_GROUP: if use_data_before_expert_parallel_: - for i in range(ep_stride): - ranks = range(i, world_size, ep_stride) - group = dist.new_group(ranks) - log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank % ep_stride): - _EXPERT_PARALLEL_GROUP[group_name] = group + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(ep_stride): + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride) + group = dist.new_group(ranks) + log_dist( + f'creating expert parallel process group named {group_name} ' + f'with ranks: {list(ranks)}', [0]) + if rank in ranks: + _EXPERT_PARALLEL_GROUP[group_name] = group else: for i in range(world_size // expert_parallel_size_): ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) group = dist.new_group(ranks) - log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank // expert_parallel_size_): + log_dist(f'creating expert parallel process group named {group_name} ' + f'with ranks: {list(ranks)}', [0]) + if rank in ranks: _EXPERT_PARALLEL_GROUP[group_name] = group def _get_expert_parallel_ranks(world_size, - model_parallel_size_, + tensor_parallel_size_, expert_parallel_size_, + pipeline_parallel_size_=1, use_data_before_expert_parallel_=False): """Generate expert parallel and expert data parallel group ranks list. @@ -193,32 +200,40 @@ def _get_expert_parallel_ranks(world_size, Args: world_size (int): Distributed world size. - model_parallel_size_ (int): Model parallel group size. + tensor_parallel_size_ (int): Tensor parallel group size. expert_parallel_size_ (int): Expert parallel group size. + pipeline_parallel_size_ (int): Pipeline parallel group size use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology Returns: Expert parallel group ranks and Expert data parallel group ranks list. 
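A hedged usage sketch of the renamed and extended helper; the call mirrors the unit test updated at the end of this patch, and pipeline_parallel_size_ defaults to 1.

from deepspeed.utils.groups import _get_expert_parallel_ranks

expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
    world_size=16,
    tensor_parallel_size_=2,
    expert_parallel_size_=4,
    pipeline_parallel_size_=1,   # >1 keeps every group within a single pipeline stage
)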
""" - _ensure_divisibility(world_size, model_parallel_size_) - dp_world_size = world_size // model_parallel_size_ + _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_) + dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_) _ensure_divisibility(dp_world_size, expert_parallel_size_) # Generate data parallel groups data_parallel_groups = [] - dp_group_size = model_parallel_size_ + dp_group_size = tensor_parallel_size_ + pp_stride = world_size // pipeline_parallel_size_ if use_data_before_expert_parallel_: - dp_stride = world_size // expert_parallel_size_ // model_parallel_size_ - for i in range(dp_group_size): - data_parallel_groups.append(list()) - for ds in range(dp_stride): - # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] - # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] - data_parallel_groups[-1].extend( - list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_))) + dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_ + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list()) + for ds in range(dp_stride): + # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] + # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] + data_parallel_groups[-1].extend( + list( + range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next, + dp_stride * tensor_parallel_size_))) else: - for i in range(dp_group_size): - data_parallel_groups.append(list(range(i, world_size, dp_group_size))) + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size))) expert_parallel_groups = [] expert_data_parallel_groups = [] @@ -252,36 +267,33 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_ expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] """ assert dist.is_initialized(), "dist is not initialized" - model_parallel_size_ = mpu.get_model_parallel_world_size() + tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu) global expert_tensor_parallel_world_size - expert_tensor_parallel_world_size = model_parallel_size_ + expert_tensor_parallel_world_size = tensor_parallel_size_ world_size = dist.get_world_size() rank = dist.get_rank() dp_world_size = mpu.get_data_parallel_world_size() - dp_rank = mpu.get_data_parallel_rank() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) - _ensure_divisibility(world_size, model_parallel_size_) + _ensure_divisibility(world_size, tensor_parallel_size_) _ensure_divisibility(dp_world_size, expert_parallel_size_) log_dist( - f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}", - [0]) + f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, " + f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, " + f"world size {world_size}, dp world size {dp_world_size}", [0]) global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP - # Get world size and rank. Ensure some consistencies. 
- _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group() - _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group() - group_name = f"ep_size_{expert_parallel_size_}" # Only create groups if they don't already exist # Need to check conditions outside the group creation loop because of the way torch.dist group creation works if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP: expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks( - world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_) + world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_) for ranks in expert_parallel_groups: group = dist.new_group(ranks) if rank in list(ranks): diff --git a/tests/unit/utils/test_groups.py b/tests/unit/utils/test_groups.py index d8f12be4f3c6..5cd35baf3510 100644 --- a/tests/unit/utils/test_groups.py +++ b/tests/unit/utils/test_groups.py @@ -18,7 +18,7 @@ def test_get_expert_parallel_ranks(): expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] """ expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(world_size=16, - model_parallel_size_=2, + tensor_parallel_size_=2, expert_parallel_size_=4) assert expert_parallel_groups == [ [0, 2, 4, 6],