From 7a140e84e1cd346ecb40654aa9834e301a3d5779 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Wed, 7 Feb 2024 12:47:37 -0800 Subject: [PATCH] Add tests for different sharding types (#1673) Summary: Expand testing for UVM and different sharding types. Reviewed By: joshuadeng Differential Revision: D53207247 --- .../test_utils/test_model_parallel.py | 202 +++++++++++++++++- .../test_utils/test_model_parallel_base.py | 69 +++++- 2 files changed, 257 insertions(+), 14 deletions(-) diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index bc5c41ba7..bef0ff9af 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -139,6 +139,13 @@ def setUp(self, backend: str = "nccl") -> None: SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), qcomms_config=st.sampled_from( [ None, @@ -162,6 +169,7 @@ def setUp(self, backend: str = "nccl") -> None: def test_sharding_rw( self, sharder_type: str, + kernel_type: str, qcomms_config: Optional[QCommsConfig], apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] @@ -174,7 +182,6 @@ def test_sharding_rw( ) sharding_type = ShardingType.ROW_WISE.value - kernel_type = EmbeddingComputeKernel.FUSED.value assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size @@ -206,6 +213,11 @@ def test_sharding_rw( SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + ], + ), apply_optimizer_in_backward_config=st.sampled_from([None]), # TODO - need to enable optimizer overlapped behavior for data_parallel tables ) @@ -213,12 +225,12 @@ def test_sharding_rw( def test_sharding_dp( self, sharder_type: str, + kernel_type: str, apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], ) -> None: sharding_type = ShardingType.DATA_PARALLEL.value - kernel_type = EmbeddingComputeKernel.DENSE.value self._test_sharding( # pyre-ignore[6] sharders=[ @@ -236,6 +248,13 @@ def test_sharding_dp( SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), qcomms_config=st.sampled_from( [ None, @@ -259,14 +278,20 @@ def test_sharding_dp( def test_sharding_cw( self, sharder_type: str, + kernel_type: str, qcomms_config: Optional[QCommsConfig], apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + sharding_type = ShardingType.COLUMN_WISE.value - kernel_type = EmbeddingComputeKernel.FUSED.value assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size @@ -300,6 +325,90 @@ def test_sharding_cw( SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + 
qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_sharding_twcw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_COLUMN_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), qcomms_config=st.sampled_from( [ # None, @@ -324,14 +433,97 @@ def test_sharding_cw( def test_sharding_tw( self, sharder_type: str, + kernel_type: str, qcomms_config: Optional[QCommsConfig], apply_optimizer_in_backward_config: Optional[ Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] ], variable_batch_size: bool, ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + sharding_type = ShardingType.TABLE_WISE.value - kernel_type = EmbeddingComputeKernel.FUSED.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + backend=self.backend, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, + backward_precision=CommType.BF16, + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + 
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) + def test_sharding_twrw( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.TABLE_ROW_WISE.value assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size @@ -364,6 +556,8 @@ def test_sharding_tw( ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value, ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, ] ), global_constant_batch=st.booleans(), diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py index a27035e26..4063def4f 100644 --- a/torchrec/distributed/test_utils/test_model_parallel_base.py +++ b/torchrec/distributed/test_utils/test_model_parallel_base.py @@ -23,10 +23,7 @@ EmbeddingComputeKernel, EmbeddingTableConfig, ) -from torchrec.distributed.embeddingbag import ( - EmbeddingBagCollectionSharder, - ShardedEmbeddingBagCollection, -) +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection from torchrec.distributed.model_parallel import DistributedModelParallel from torchrec.distributed.planner import ( @@ -469,13 +466,18 @@ def test_meta_device_dmp_state_dict(self) -> None: ), sharding_type=st.sampled_from( [ + ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, ] ), kernel_type=st.sampled_from( [ EmbeddingComputeKernel.FUSED.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, ] ), ) @@ -587,17 +589,44 @@ def test_load_state_dict_dp( # pyre-ignore[56] @given( - sharders=st.sampled_from( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( [ - [EmbeddingBagCollectionSharder()], - # [EmbeddingBagSharder()], + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, ] ), ) @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) def test_load_state_dict_prefix( - self, sharders: List[ModuleSharder[nn.Module]] + self, sharder_type: str, sharding_type: str, kernel_type: str ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder(sharder_type, sharding_type, kernel_type), + ), + ] (m1, m2), batch = self._generate_dmps_and_batch(sharders) # load the second's (m2's) with the first (m1's) state_dict @@ -632,19 +661,31 @@ def 
test_load_state_dict_prefix( sharding_type=st.sampled_from( [ ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.TABLE_COLUMN_WISE.value, ] ), kernel_type=st.sampled_from( [ # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) def test_params_and_buffers( self, sharder_type: str, sharding_type: str, kernel_type: str ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + sharders = [ create_test_sharder(sharder_type, sharding_type, kernel_type), ] @@ -671,13 +712,21 @@ def test_params_and_buffers( kernel_type=st.sampled_from( [ EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) def test_load_state_dict_cw_multiple_shards( self, sharder_type: str, sharding_type: str, kernel_type: str ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + sharders = [ cast( ModuleSharder[nn.Module],
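
The new test cases above all share one pattern: hypothesis draws the embedding
compute kernel (FUSED, FUSED_UVM, FUSED_UVM_CACHING) alongside the sharder and
sharding type, and the test skips early on CPU-only runs because the UVM
kernels need CUDA-managed memory. A minimal, self-contained sketch of that
pattern follows; the test class, the absence of real table set-up, and the
plain string kernel values are illustrative assumptions, not part of the patch.

    import unittest

    import hypothesis.strategies as st
    import torch
    from hypothesis import Verbosity, given, settings


    class ShardingKernelGuardSketch(unittest.TestCase):
        # Hypothetical stand-in for the torchrec test base; only a device is needed here.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        @given(
            kernel_type=st.sampled_from(
                [
                    "fused",              # assumed value of EmbeddingComputeKernel.FUSED
                    "fused_uvm",          # assumed value of EmbeddingComputeKernel.FUSED_UVM
                    "fused_uvm_caching",  # assumed value of EmbeddingComputeKernel.FUSED_UVM_CACHING
                ]
            ),
        )
        @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
        def test_kernel_guard(self, kernel_type: str) -> None:
            # Same guard the patch adds to each GPU test: UVM kernels are only
            # exercised when a GPU is available.
            if self.device == torch.device("cpu") and kernel_type != "fused":
                self.skipTest("CPU does not support uvm.")
            # The real tests proceed to self._test_sharding(...) from here.
            self.assertIn(kernel_type, {"fused", "fused_uvm", "fused_uvm_caching"})


    if __name__ == "__main__":
        unittest.main()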