Add tests for different sharding types (pytorch#1673)
Summary:

Expand test coverage to the UVM compute kernels (FUSED_UVM, FUSED_UVM_CACHING) and to additional sharding types (TABLE_ROW_WISE, TABLE_COLUMN_WISE).

Reviewed By: joshuadeng

Differential Revision: D53207247
henrylhtsang authored and facebook-github-bot committed Feb 7, 2024
1 parent 3daadd9 commit 7a140e8
Showing 2 changed files with 257 additions and 14 deletions.
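The recurring change in the hunks below: each sharding test previously hard-coded kernel_type = EmbeddingComputeKernel.FUSED.value, and now receives kernel_type as a Hypothesis parameter drawn with st.sampled_from, adding the UVM kernels. FUSED_UVM keeps embedding weights in unified (host) memory mapped into the GPU's address space, and FUSED_UVM_CACHING fronts it with a device-side cache; both require CUDA, which is why the new tests skip them on CPU. A minimal sketch of the pattern, with check_kernel as a hypothetical stand-in for the real _test_sharding harness:

```python
from hypothesis import Verbosity, given, settings, strategies as st

from torchrec.distributed.embedding_types import EmbeddingComputeKernel


# Sketch only: check_kernel stands in for the real _test_sharding harness.
def check_kernel(kernel_type: str) -> None:
    assert kernel_type in {k.value for k in EmbeddingComputeKernel}


@given(
    kernel_type=st.sampled_from(
        [
            EmbeddingComputeKernel.FUSED.value,
            EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
            EmbeddingComputeKernel.FUSED_UVM.value,
        ]
    )
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_kernels(kernel_type: str) -> None:
    check_kernel(kernel_type)
```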
202 changes: 198 additions & 4 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -139,6 +139,13 @@ def setUp(self, backend: str = "nccl") -> None:
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
),
qcomms_config=st.sampled_from(
[
None,
@@ -162,6 +169,7 @@ def setUp(self, backend: str = "nccl") -> None:
def test_sharding_rw(
self,
sharder_type: str,
kernel_type: str,
qcomms_config: Optional[QCommsConfig],
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
@@ -174,7 +182,6 @@ def test_sharding_rw(
)

sharding_type = ShardingType.ROW_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
assume(
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
or not variable_batch_size
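The assume(...) call above is Hypothesis's mechanism for discarding, rather than failing, generated examples that violate a precondition; here it restricts variable batch sizes to the EMBEDDING_BAG_COLLECTION sharder. A generic sketch of the idiom:

```python
from hypothesis import assume, given, strategies as st


@given(a=st.integers(), b=st.integers())
def test_exact_division(a: int, b: int) -> None:
    assume(b != 0)  # invalid examples are skipped, not reported as failures
    assert (a * b) // b == a
```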
@@ -206,19 +213,24 @@ def test_sharding_rw(
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.DENSE.value,
],
),
apply_optimizer_in_backward_config=st.sampled_from([None]),
# TODO - need to enable optimizer overlapped behavior for data_parallel tables
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_dp(
self,
sharder_type: str,
kernel_type: str,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
) -> None:
sharding_type = ShardingType.DATA_PARALLEL.value
kernel_type = EmbeddingComputeKernel.DENSE.value
self._test_sharding(
# pyre-ignore[6]
sharders=[
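test_sharding_dp pins the kernel to EmbeddingComputeKernel.DENSE and samples only None for apply_optimizer_in_backward_config: data-parallel tables are full replicas updated by an ordinary optimizer, and per the TODO above, optimizer-overlapped (in-backward) updates are not yet enabled for them. For reference, a sketch of the in-backward mechanism the other tests exercise, using the private torch.distributed.optim API that torchrec aliases (the exact import path is an assumption, not read from this diff):

```python
import torch
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)

# The optimizer step runs inside backward, as each gradient is produced,
# instead of in a separate optimizer.step() pass afterwards.
model = torch.nn.Linear(8, 4)
apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
model(torch.randn(2, 8)).sum().backward()  # parameters update during backward
```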
@@ -236,6 +248,13 @@ def test_sharding_dp(
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
),
qcomms_config=st.sampled_from(
[
None,
@@ -259,14 +278,20 @@
def test_sharding_cw(
self,
sharder_type: str,
kernel_type: str,
qcomms_config: Optional[QCommsConfig],
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
if (
self.device == torch.device("cpu")
and kernel_type != EmbeddingComputeKernel.FUSED.value
):
self.skipTest("CPU does not support uvm.")

sharding_type = ShardingType.COLUMN_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
assume(
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
or not variable_batch_size
@@ -300,6 +325,90 @@ def test_sharding_cw(
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
),
qcomms_config=st.sampled_from(
[
None,
QCommsConfig(
forward_precision=CommType.FP16, backward_precision=CommType.BF16
),
]
),
apply_optimizer_in_backward_config=st.sampled_from(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_twcw(
self,
sharder_type: str,
kernel_type: str,
qcomms_config: Optional[QCommsConfig],
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
if (
self.device == torch.device("cpu")
and kernel_type != EmbeddingComputeKernel.FUSED.value
):
self.skipTest("CPU does not support uvm.")

sharding_type = ShardingType.TABLE_COLUMN_WISE.value
assume(
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
or not variable_batch_size
)
self._test_sharding(
# pyre-ignore[6]
sharders=[
create_test_sharder(
sharder_type,
sharding_type,
kernel_type,
qcomms_config=qcomms_config,
device=self.device,
),
],
backend=self.backend,
qcomms_config=qcomms_config,
constraints={
table.name: ParameterConstraints(min_partition=4)
for table in self.tables
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
)

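test_sharding_twcw is new in this commit; it mirrors the column-wise test but requests TABLE_COLUMN_WISE placement and pins min_partition=4, the smallest number of embedding-dimension columns a shard may hold. A sketch of feeding the same constraints to a planner directly (the planner import paths and Topology arguments are standard torchrec usage, not taken from this diff):

```python
from typing import Dict, List

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.modules.embedding_configs import EmbeddingBagConfig


def make_planner(tables: List[EmbeddingBagConfig]) -> EmbeddingShardingPlanner:
    # Every column-wise shard of each table must span at least 4 columns.
    constraints: Dict[str, ParameterConstraints] = {
        table.name: ParameterConstraints(min_partition=4) for table in tables
    }
    return EmbeddingShardingPlanner(
        topology=Topology(world_size=2, compute_device="cuda"),
        constraints=constraints,
    )
```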
# pyre-fixme[56]
@given(
sharder_type=st.sampled_from(
[
# SharderType.EMBEDDING_BAG.value,
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
),
qcomms_config=st.sampled_from(
[
# None,
@@ -324,14 +433,97 @@ def test_sharding_cw(
def test_sharding_tw(
self,
sharder_type: str,
kernel_type: str,
qcomms_config: Optional[QCommsConfig],
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
if (
self.device == torch.device("cpu")
and kernel_type != EmbeddingComputeKernel.FUSED.value
):
self.skipTest("CPU does not support uvm.")

sharding_type = ShardingType.TABLE_WISE.value
kernel_type = EmbeddingComputeKernel.FUSED.value
assume(
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
or not variable_batch_size
)
self._test_sharding(
# pyre-ignore[6]
sharders=[
create_test_sharder(
sharder_type,
sharding_type,
kernel_type,
qcomms_config=qcomms_config,
device=self.device,
),
],
backend=self.backend,
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
)

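Most of these tests also sample a QCommsConfig that quantizes collective traffic, FP16 on the forward all-to-all and BF16 on the backward pass. create_test_sharder threads it through, but the underlying wiring is a codec registry; a sketch, assuming torchrec's fbgemm_qcomm_codec module and the qcomm_codecs_registry keyword on sharders (stated as an assumption rather than read from this diff):

```python
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fbgemm_qcomm_codec import (
    CommType,
    QCommsConfig,
    get_qcomm_codecs_registry,
)

qcomms_config = QCommsConfig(
    forward_precision=CommType.FP16,  # quantize forward all-to-all to FP16
    backward_precision=CommType.BF16,  # quantize backward comms to BF16
)
sharder = EmbeddingBagCollectionSharder(
    qcomm_codecs_registry=get_qcomm_codecs_registry(qcomms_config),
)
```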
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-fixme[56]
@given(
sharder_type=st.sampled_from(
[
# SharderType.EMBEDDING_BAG.value,
SharderType.EMBEDDING_BAG_COLLECTION.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
),
qcomms_config=st.sampled_from(
[
# None,
QCommsConfig(
forward_precision=CommType.FP16,
backward_precision=CommType.BF16,
),
]
),
apply_optimizer_in_backward_config=st.sampled_from(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_twrw(
self,
sharder_type: str,
kernel_type: str,
qcomms_config: Optional[QCommsConfig],
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
if self.backend == "gloo":
self.skipTest(
"Gloo reduce_scatter_base fallback not supported with async_op=True"
)

sharding_type = ShardingType.TABLE_ROW_WISE.value
assume(
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value
or not variable_batch_size
@@ -364,6 +556,8 @@ def test_sharding_tw(
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
]
),
global_constant_batch=st.booleans(),
(The diff of the second changed file is not shown.)

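Read together, the hunks give the kernel coverage below per sharding type: DENSE only for data-parallel, the fused family everywhere else, with UVM kernels skipped on CPU and test_sharding_twrw additionally skipped on the gloo backend, whose reduce_scatter_base fallback does not support async_op=True. A recap in code form, derived from the hunks above:

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType

# Kernel coverage per sharding type after this commit.
FUSED_FAMILY = [
    EmbeddingComputeKernel.FUSED.value,
    EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
    EmbeddingComputeKernel.FUSED_UVM.value,
]
KERNELS_BY_SHARDING = {
    ShardingType.ROW_WISE.value: FUSED_FAMILY,
    ShardingType.DATA_PARALLEL.value: [EmbeddingComputeKernel.DENSE.value],
    ShardingType.COLUMN_WISE.value: FUSED_FAMILY,
    ShardingType.TABLE_COLUMN_WISE.value: FUSED_FAMILY,
    ShardingType.TABLE_WISE.value: FUSED_FAMILY,
    ShardingType.TABLE_ROW_WISE.value: FUSED_FAMILY,
}
```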