Merge branch 'master' into moe/pipe
mosheisland authored Apr 7, 2024
2 parents 526ce7f + 42a8eaa commit 4d8bf27
Showing 3 changed files with 69 additions and 1 deletion.
30 changes: 30 additions & 0 deletions deepspeed/moe/utils.py
@@ -150,3 +150,33 @@ def split_params_into_different_moe_groups_for_optimizer(

def is_moe_param_group(param_group):
    return param_group.get('moe', False)


def configure_moe_param_groups(model_parameters: List):
    assert isinstance(model_parameters, list), "model_parameters must be a list"

    for p in model_parameters:
        # match torch.optim.Optimizer expectations,
        # see: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/torch/optim/optimizer.py#L270-L272
        if not isinstance(p, (torch.Tensor, dict)):
            raise TypeError("param argument that would be given to the optimizer should be "
                            f"an iterable of Tensors or dicts, but got {type(p)}")

    # peek at the first element to determine how to proceed
    first = model_parameters[0]

    # Case 1: model_parameters is a list of torch.nn.Parameter
    # -> need to create moe compatible param groups
    if isinstance(first, torch.nn.Parameter):
        param_group = {'params': model_parameters, 'name': 'dense-params'}
        return split_params_into_different_moe_groups_for_optimizer(param_group)

    # Case 2: model_parameters is a list of param groups List[dict]
    # -> moe compatible param groups might already exist; if not, create them
    elif isinstance(first, dict):
        # no moe groups have been created yet
        if not any(['moe' in param_group for param_group in model_parameters]):
            return split_params_into_different_moe_groups_for_optimizer(model_parameters)
        else:
            # moe groups exist, nothing to do
            return model_parameters
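
For context, a minimal sketch of how the new helper could be exercised directly. `net` is a hypothetical nn.Module that contains deepspeed.moe.layer.MoE layers; the comments follow the two cases handled in configure_moe_param_groups above.

import torch
from deepspeed.moe.utils import configure_moe_param_groups

# Case 1: a flat list of Parameters is wrapped into a single 'dense-params'
# group and then split into MoE-compatible param groups.
groups = configure_moe_param_groups(list(net.parameters()))

# Case 2: a list of param-group dicts is split the same way, unless some group
# already carries a 'moe' key, in which case the list is returned unchanged.
groups = configure_moe_param_groups([{'params': list(net.parameters()), 'lr': 1e-4}])

# Either result can be handed to a torch optimizer as its param groups.
optimizer = torch.optim.Adam(groups, lr=1e-4)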
4 changes: 3 additions & 1 deletion deepspeed/runtime/engine.py
@@ -94,7 +94,7 @@
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
from ..moe.utils import is_moe_param
from ..moe.utils import is_moe_param, configure_moe_param_groups
from ..git_version_info import version

from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
@@ -1227,6 +1227,8 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):
        if client_optimizer is None:
            if self.has_moe_layers:
                model_parameters = configure_moe_param_groups(model_parameters)
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
        else:
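
The practical effect of this engine change, sketched below under assumed names (`moe_model` is a hypothetical module containing deepspeed.moe.layer.MoE layers): when the optimizer is defined in the DeepSpeed config rather than constructed by the caller (client_optimizer is None), the engine now builds MoE-compatible param groups itself, so the caller no longer has to split them before deepspeed.initialize.

import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

# client_optimizer is None here, so _configure_optimizer routes the model
# parameters through configure_moe_param_groups before creating the Adam optimizer.
engine, optimizer, _, _ = deepspeed.initialize(config=ds_config, model=moe_model)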
36 changes: 36 additions & 0 deletions tests/unit/moe/test_moe.py
@@ -16,6 +16,42 @@
from deepspeed.runtime.utils import required_torch_version


@pytest.mark.parametrize("zero_stage", [0, 1, 2])
class TestSimpleMoE(DistributedTest):
world_size = 2

def test(self, zero_stage):
if not required_torch_version(min_version=1.8):
pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
}
}
# should automatically create moe param groups in deepspeed backend
hidden_dim = 16
model = SimpleMoEModel(hidden_dim=hidden_dim, ep_size=1)
model, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model)
data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()


@pytest.mark.parametrize("ep_size", [2, 4])
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
@pytest.mark.parametrize("use_residual", [True, False])
