MOE: Enhance expert group creation for pipeline
This commit enhances expert group creation so that expert and expert-data groups are built within each pipeline stage, covering both modes:
- DP + PP + EP
- DP + TP + PP + EP

Signed-off-by: Moshe Island <[email protected]>
misland-habana committed Mar 30, 2024
1 parent 83bda4e commit 5831321
Showing 2 changed files with 61 additions and 49 deletions.
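For orientation, the sketch below is not code from the repository; the helper name build_moe_groups and the example sizes are hypothetical. It mirrors the per-pipeline-stage layout this commit establishes for the default E + D topology: within each pipeline stage, the data-parallel ranks of each tensor-parallel slice are split into consecutive expert-parallel groups, and matching positions across those chunks form the expert-data-parallel groups. With pipeline_parallel_size=1 it reproduces the example documented in groups.py and checked by the unit test.

```python
# Illustrative sketch only -- not part of this commit. Helper name and sizes are hypothetical.

def build_moe_groups(world_size, tensor_parallel_size, pipeline_parallel_size, expert_parallel_size):
    """Return (expert_parallel_groups, expert_data_parallel_groups), built per pipeline stage."""
    pp_stride = world_size // pipeline_parallel_size  # ranks per pipeline stage
    expert_parallel_groups, expert_data_parallel_groups = [], []
    for pp_stage_start in range(0, world_size, pp_stride):  # one pass per pipeline stage
        for tp_rank in range(tensor_parallel_size):
            # data-parallel ranks of this (pipeline stage, tensor-parallel rank) slice
            dp_ranks = list(range(pp_stage_start + tp_rank, pp_stage_start + pp_stride, tensor_parallel_size))
            # consecutive chunks of the data-parallel ranks form expert-parallel groups ...
            chunks = [dp_ranks[i:i + expert_parallel_size] for i in range(0, len(dp_ranks), expert_parallel_size)]
            expert_parallel_groups.extend(chunks)
            # ... and matching positions across those chunks form expert-data-parallel groups
            expert_data_parallel_groups.extend(list(col) for col in zip(*chunks))
    return expert_parallel_groups, expert_data_parallel_groups


# pipeline_parallel_size=1 reproduces the example documented in groups.py:
#   expert parallel:      [0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]
#   expert data parallel: [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
print(build_moe_groups(16, tensor_parallel_size=2, pipeline_parallel_size=1, expert_parallel_size=4))

# With two pipeline stages every group stays inside its stage (ranks 0-7 or 8-15):
#   expert parallel:      [0,2],[4,6], [1,3],[5,7], [8,10],[12,14], [9,11],[13,15]
#   expert data parallel: [0,4],[2,6], [1,5],[3,7], [8,12],[10,14], [9,13],[11,15]
print(build_moe_groups(16, tensor_parallel_size=2, pipeline_parallel_size=2, expert_parallel_size=2))
```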
108 changes: 60 additions & 48 deletions deepspeed/utils/groups.py
@@ -27,6 +27,7 @@

from deepspeed import comm as dist
from deepspeed.utils import log_dist
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator
# Expert parallel group that the current rank belongs to.
@@ -128,31 +129,32 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe

log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
world_size = dist.get_world_size()
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
rank = dist.get_rank()

_ensure_divisibility(world_size, expert_parallel_size_)
pp_stride = world_size // pp_world_size
_ensure_divisibility(pp_stride, expert_parallel_size_)

group_name = f"ep_size_{expert_parallel_size_}"

# Build the expert data parallel groups.
global _EXPERT_DATA_PARALLEL_GROUP

ep_stride = world_size // expert_parallel_size_
ep_stride = pp_stride // expert_parallel_size_

# Only create group if it does not already exist
if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
for i in range(expert_parallel_size_):
if use_data_before_expert_parallel_:
ranks = range(i * ep_stride, (i + 1) * ep_stride)
else:
ranks = range(i, world_size, expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if use_data_before_expert_parallel_:
if i == (rank // ep_stride):
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group
else:
if i == (rank % expert_parallel_size_):
for pp_stage_start in range(0, world_size, pp_stride):
for i in range(expert_parallel_size_):
if use_data_before_expert_parallel_:
ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
else:
ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(
f'Creating expert data parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group

# Build the expert parallel groups.
@@ -161,24 +163,29 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
# Only create group if it does not already exist
if group_name not in _EXPERT_PARALLEL_GROUP:
if use_data_before_expert_parallel_:
for i in range(ep_stride):
ranks = range(i, world_size, ep_stride)
group = dist.new_group(ranks)
log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if i == (rank % ep_stride):
_EXPERT_PARALLEL_GROUP[group_name] = group
for pp_stage_start in range(0, world_size, pp_stride):
for i in range(ep_stride):
ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
group = dist.new_group(ranks)
log_dist(
f'creating expert parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_PARALLEL_GROUP[group_name] = group
else:
for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if i == (rank // expert_parallel_size_):
log_dist(f'creating expert parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_PARALLEL_GROUP[group_name] = group


def _get_expert_parallel_ranks(world_size,
model_parallel_size_,
tensor_parallel_size_,
expert_parallel_size_,
pipeline_parallel_size_=1,
use_data_before_expert_parallel_=False):
"""Generate expert parallel and expert data parallel group ranks list.
@@ -193,32 +200,40 @@ def _get_expert_parallel_ranks(world_size,
Args:
world_size (int): Distributed world size.
model_parallel_size_ (int): Model parallel group size.
tensor_parallel_size_ (int): Tensor parallel group size.
expert_parallel_size_ (int): Expert parallel group size.
pipeline_parallel_size_ (int): Pipeline parallel group size
use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
Returns:
Expert parallel group ranks and Expert data parallel group ranks list.
"""
_ensure_divisibility(world_size, model_parallel_size_)
dp_world_size = world_size // model_parallel_size_
_ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_)
dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_)
_ensure_divisibility(dp_world_size, expert_parallel_size_)

# Generate data parallel groups
data_parallel_groups = []
dp_group_size = model_parallel_size_
dp_group_size = tensor_parallel_size_
pp_stride = world_size // pipeline_parallel_size_

if use_data_before_expert_parallel_:
dp_stride = world_size // expert_parallel_size_ // model_parallel_size_
for i in range(dp_group_size):
data_parallel_groups.append(list())
for ds in range(dp_stride):
# [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
# [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
data_parallel_groups[-1].extend(
list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_)))
dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_
for pp_stage_start in range(0, world_size, pp_stride):
pp_stage_next = pp_stage_start + pp_stride
for i in range(dp_group_size):
data_parallel_groups.append(list())
for ds in range(dp_stride):
# [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
# [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
data_parallel_groups[-1].extend(
list(
range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next,
dp_stride * tensor_parallel_size_)))
else:
for i in range(dp_group_size):
data_parallel_groups.append(list(range(i, world_size, dp_group_size)))
for pp_stage_start in range(0, world_size, pp_stride):
pp_stage_next = pp_stage_start + pp_stride
for i in range(dp_group_size):
data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size)))

expert_parallel_groups = []
expert_data_parallel_groups = []
@@ -252,36 +267,33 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_
expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
"""
assert dist.is_initialized(), "dist is not initialized"
model_parallel_size_ = mpu.get_model_parallel_world_size()
tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu)

global expert_tensor_parallel_world_size
expert_tensor_parallel_world_size = model_parallel_size_
expert_tensor_parallel_world_size = tensor_parallel_size_

world_size = dist.get_world_size()
rank = dist.get_rank()
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)

_ensure_divisibility(world_size, model_parallel_size_)
_ensure_divisibility(world_size, tensor_parallel_size_)
_ensure_divisibility(dp_world_size, expert_parallel_size_)

log_dist(
f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
[0])
f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, "
f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, "
f"world size {world_size}, dp world size {dp_world_size}", [0])

global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

# Get world size and rank. Ensure some consistencies.
_DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
_MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()

group_name = f"ep_size_{expert_parallel_size_}"

# Only create groups if they don't already exist
# Need to check conditions outside the group creation loop because of the way torch.dist group creation works
if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_)
world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_)
for ranks in expert_parallel_groups:
group = dist.new_group(ranks)
if rank in list(ranks):
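The first hunk above applies the same per-stage restriction to _create_expert_and_data_parallel (the DP + PP + EP mode, without tensor parallelism). The sketch below mirrors the range arithmetic visible in that hunk using hypothetical sizes (16 ranks, 2 pipeline stages, expert parallel size 2); the real function additionally calls dist.new_group(ranks) and caches the group containing the local rank under group_name.

```python
# Illustration only, with hypothetical sizes.
world_size, pp_world_size, ep = 16, 2, 2
pp_stride = world_size // pp_world_size   # ranks per pipeline stage: 8
ep_stride = pp_stride // ep               # 4

# Default E + D topology: expert groups remain contiguous blocks of size ep (unchanged code),
# while expert-data groups now stride by ep *within* each pipeline stage.
ep_groups = [list(range(i * ep, (i + 1) * ep)) for i in range(world_size // ep)]
edp_groups = [list(range(start + i, start + pp_stride, ep))
              for start in range(0, world_size, pp_stride) for i in range(ep)]
print(ep_groups)   # [[0, 1], [2, 3], ..., [14, 15]]
print(edp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]

# D + E topology (use_data_before_expert_parallel_=True): the roles flip -- expert-data
# groups are contiguous blocks of ep_stride ranks, expert groups stride by ep_stride,
# and both stay confined to a single pipeline stage.
edp_groups_dbe = [list(range(start + i * ep_stride, start + (i + 1) * ep_stride))
                  for start in range(0, world_size, pp_stride) for i in range(ep)]
ep_groups_dbe = [list(range(start + i, start + pp_stride, ep_stride))
                 for start in range(0, world_size, pp_stride) for i in range(ep_stride)]
print(edp_groups_dbe)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(ep_groups_dbe)   # [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14], [11, 15]]
```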
2 changes: 1 addition & 1 deletion tests/unit/utils/test_groups.py
@@ -18,7 +18,7 @@ def test_get_expert_parallel_ranks():
expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
"""
expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(world_size=16,
model_parallel_size_=2,
tensor_parallel_size_=2,
expert_parallel_size_=4)
assert expert_parallel_groups == [
[0, 2, 4, 6],
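A possible follow-up to the test above, sketched here as an illustration rather than as part of the commit: it exercises the new pipeline_parallel_size_ argument and only checks the property the change is meant to provide, namely that every group stays inside a single pipeline stage.

```python
# Hypothetical extension of the unit test, not part of this commit.
from deepspeed.utils.groups import _get_expert_parallel_ranks


def test_get_expert_parallel_ranks_with_pipeline():
    expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
        world_size=16, tensor_parallel_size_=2, expert_parallel_size_=2, pipeline_parallel_size_=2)
    # With 2 pipeline stages of 8 ranks each, no group may span both stages.
    for group in expert_parallel_groups + expert_data_parallel_groups:
        stages = {rank // 8 for rank in group}
        assert len(stages) == 1
```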
