Commit 5d29ad7
fix comms dtype (#5297)
Is the comms dtype name a bug? Can we fix it?

Co-authored-by: Mayank Mishra <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
3 people authored Mar 27, 2024
1 parent 4520edd commit 5d29ad7
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions deepspeed/runtime/config.py
@@ -258,10 +258,10 @@ def get_communication_data_type(param_dict,
         return torch.float32
     elif val == "fp16":
         return torch.float16
-    elif val == "bfp16":
+    elif val == "bf16":
         return torch.bfloat16
 
-    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}")
+    raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bf16', 'fp32']. Got: {val}")
 
 
 def get_prescale_gradients(param_dict):
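For reference, a minimal sketch of a DeepSpeed config dict that exercises the corrected spelling. The surrounding keys are illustrative (borrowed from the test below), not taken from this commit:

    # Hypothetical config snippet: "bf16" is the accepted spelling going forward,
    # parsed by get_communication_data_type into torch.bfloat16.
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "communication_data_type": "bf16",
    }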
6 changes: 3 additions & 3 deletions tests/unit/runtime/half_precision/test_bf16.py
@@ -288,8 +288,8 @@ def test(self, stage=2):
             model.step()
 
 
-@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"])
-@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bfp16", "default"])
+@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
+@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bf16", "default"])
 class TestZeroDtypeCocktail(DistributedTest):
     world_size = 2
 
@@ -304,7 +304,7 @@ def test(self, comp_type, comm_type):
         if not get_accelerator().is_fp16_supported():
             pytest.skip("fp16 is not supported")
 
-        type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"}
+        type_str = {torch.float16: "fp16", torch.bfloat16: "bf16"}
 
         config_dict = {
             "train_micro_batch_size_per_gpu": 2,
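The diff elides the remainder of the test body. A hedged sketch of how the parametrized comm_type plausibly flows into that config (assumed, not shown in the diff):

    # Assumption: the test translates the torch dtype back to its config string
    # via type_str, so the id must match what get_communication_data_type accepts.
    if comm_type is not None:
        config_dict["communication_data_type"] = type_str[comm_type]  # torch.bfloat16 -> "bf16"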
