From 873faf34c4eb3371c0d6848a449fedebb25d7022 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Sat, 18 Nov 2023 04:59:48 -0800 Subject: [PATCH] RW Dist change to support uneven sharding (#1525) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1525 X-link: https://github.com/pytorch/FBGEMM/pull/2142 As titled Reviewed By: dstaay-fb Differential Revision: D51053207 fbshipit-source-id: 8aee7d967ceec9ea0739f5d2c56d6e541a0d3648 --- torchrec/distributed/embedding_sharding.py | 3 + torchrec/distributed/sharding/rw_sharding.py | 82 ++++++++- .../distributed/tests/test_infer_shardings.py | 156 +++++++++++++++++- 3 files changed, 226 insertions(+), 15 deletions(-) diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 60c6d51f3..ce418f9c5 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -63,6 +63,7 @@ def bucketize_kjt_before_all2all( block_sizes: torch.Tensor, output_permute: bool = False, bucketize_pos: bool = False, + block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: """ Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets, @@ -78,6 +79,7 @@ def bucketize_kjt_before_all2all( values to bucketized values or not. bucketize_pos (bool): output the changed position of the bucketized values or not. + block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature. Returns: Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized value. @@ -103,6 +105,7 @@ def bucketize_kjt_before_all2all( block_sizes=block_sizes_new_type, my_size=num_buckets, weights=kjt.weights_or_none(), + block_bucketize_pos=block_bucketize_row_pos, ) return ( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index 17d64bfc9..4f1910e98 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -5,7 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Optional, TypeVar +import math +from typing import Any, Dict, List, Optional, Tuple, TypeVar import torch import torch.distributed as dist @@ -55,6 +56,40 @@ W = TypeVar("W") +def get_embedding_shard_metadata( + grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]], +) -> Tuple[List[List[int]], bool]: + is_even_sharding: bool = True + world_size = len(grouped_embedding_configs_per_rank) + + def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]: + block_size: int = math.ceil(hash_size / world_size) + last_rank: int = hash_size // block_size + + expected_even_shard_sizes = [block_size] * last_rank + if hash_size % world_size != 0: + expected_even_shard_sizes.append(hash_size - sum(expected_even_shard_sizes)) + return expected_even_shard_sizes + + embed_sharding = [] + for table in grouped_embedding_configs_per_rank[0][0].embedding_tables: + embed_sharding_per_feature = [] + total_rows = 0 + sizes = [] + # pyre-ignore [16]: `Optional` has no attribute `shards_metadata` + for metadata in table.global_metadata.shards_metadata: + embed_sharding_per_feature.append(metadata.shard_offsets[0]) + total_rows += metadata.shard_sizes[0] + sizes.append(metadata.shard_sizes[0]) + embed_sharding_per_feature.append(total_rows) + embed_sharding.extend([embed_sharding_per_feature] * len(table.embedding_names)) + expected_even_sizes = get_even_shard_sizes(total_rows, world_size) + if sizes != expected_even_sizes: + is_even_sharding = False + + return (embed_sharding, is_even_sharding) + + class BaseRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]): """ Base class for row-wise sharding. @@ -428,14 +463,27 @@ def create_output_dist( def get_block_sizes_runtime_device( block_sizes: List[int], runtime_device: torch.device, - tensor_cache: Dict[str, torch.Tensor], -) -> torch.Tensor: + tensor_cache: Dict[str, Tuple[torch.Tensor, List[torch.Tensor]]], + embedding_shard_metadata: Optional[List[List[int]]], +) -> Tuple[torch.Tensor, List[torch.Tensor]]: cache_key: str = "__block_sizes" if cache_key not in tensor_cache: - tensor_cache[cache_key] = torch.tensor( - block_sizes, - device=runtime_device, - dtype=torch.int32, + tensor_cache[cache_key] = ( + torch.tensor( + block_sizes, + device=runtime_device, + dtype=torch.int32, + ), + [] + if embedding_shard_metadata is None + else [ + torch.tensor( + row_pos, + device=runtime_device, + dtype=torch.int32, + ) + for row_pos in embedding_shard_metadata + ], ) return tensor_cache[cache_key] @@ -451,6 +499,7 @@ def __init__( is_sequence: bool = False, has_feature_processor: bool = False, need_pos: bool = False, + embedding_shard_metadata: Optional[List[List[int]]] = None, ) -> None: super().__init__() self._world_size: int = world_size @@ -459,7 +508,10 @@ def __init__( (hash_size + self._world_size - 1) // self._world_size for hash_size in feature_hash_sizes ] - self.tensor_cache: Dict[str, torch.Tensor] = {} + self.tensor_cache: Dict[ + str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]] + ] = {} + self._dist = KJTOneToAll( splits=self._world_size * [self._num_features], world_size=world_size, @@ -470,14 +522,19 @@ def __init__( self._need_pos = need_pos self.unbucketize_permute_tensor: Optional[torch.Tensor] = None + self._embedding_shard_metadata: Optional[ + List[List[int]] + ] = embedding_shard_metadata + def forward( self, sparse_features: KeyedJaggedTensor, ) -> KJTList: - block_sizes = get_block_sizes_runtime_device( + block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device( self.feature_block_sizes, sparse_features.device(), self.tensor_cache, + self._embedding_shard_metadata, ) ( bucketized_features, @@ -490,6 +547,7 @@ def forward( bucketize_pos=self._has_feature_processor if sparse_features.weights_or_none() is None else self._need_pos, + block_bucketize_row_pos=block_bucketize_row_pos, ) return self._dist.forward(bucketized_features) @@ -505,11 +563,17 @@ def create_input_dist( ) -> BaseSparseFeaturesDist[KJTList]: num_features = self._get_num_features() feature_hash_sizes = self._get_feature_hash_sizes() + + (embed_sharding, is_even_sharding) = get_embedding_shard_metadata( + self._grouped_embedding_configs_per_rank + ) + return InferRwSparseFeaturesDist( world_size=self._world_size, num_features=num_features, feature_hash_sizes=feature_hash_sizes, device=device if device is not None else self._device, + embedding_shard_metadata=embed_sharding if not is_even_sharding else None, ) def create_lookup( diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index d24371580..ca17f42cb 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -49,6 +49,13 @@ from torchrec.fx import symbolic_trace +def placement_helper(device_type: str, index: int = 0) -> str: + if device_type == "cpu": + return f"rank:0/{device_type}" # cpu only use rank 0 + + return f"rank:{index}/{device_type}:{index}" + + class InferShardingsTest(unittest.TestCase): @unittest.skipIf( torch.cuda.device_count() <= 1, @@ -702,6 +709,7 @@ def test_rw_sequence(self, weight_dtype: torch.dtype) -> None: ShardingType.ROW_WISE.value, ) + @unittest.skip("Need D51309697 before turn on") @unittest.skipIf( torch.cuda.device_count() <= 2, "Not enough GPUs available", @@ -728,7 +736,7 @@ def test_rw_uneven_sharding( local_size = 3 world_size = 3 batch_size = 4 - local_device = torch.device("cuda:0") + local_device = torch.device("cpu") mi = create_test_model( num_embeddings, emb_dim, @@ -792,11 +800,147 @@ def test_rw_uneven_sharding( sharded_model.load_state_dict(non_sharded_model.state_dict()) # We need this first inference to make all lazy init in forward - _ = sharded_model(*inputs[0]) - _ = non_sharded_model(*inputs[0]) - # TODO (drqiangzhang): Add comparison between sharded and unsharded model outputs + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) + + gm: torch.fx.GraphModule = symbolic_trace(sharded_model) + gm_script = torch.jit.script(gm) + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + + @unittest.skip("Need D51309697 before turn on") + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs available", + ) + # pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given( + weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]), + ) + @settings(max_examples=4, deadline=None) + def test_rw_uneven_sharding_mutiple_table( + self, + weight_dtype: torch.dtype, + ) -> None: + num_embeddings = 512 + emb_dim = 64 + local_size = 4 + world_size = 4 + batch_size = 1 + local_device = torch.device("cpu") + mi = create_test_model( + num_embeddings, + emb_dim, + world_size, + batch_size, + dense_device=local_device, + sparse_device=local_device, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + num_features=3, + ) + + non_sharded_model = mi.quant_model + expected_shards = [ + [ + ( + (0, 0, 256, 64), + placement_helper("cpu", 0), + ), + ( + (256, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (384, 0, 64, 64), + placement_helper("cpu", 0), + ), + ( + (448, 0, 64, 64), + placement_helper("cpu", 0), + ), + ], + [ + ( + (0, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (128, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (256, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (384, 0, 128, 64), + placement_helper("cpu", 0), + ), + ], + [ + ( + (0, 0, 256, 64), + placement_helper("cpu", 0), + ), + ( + (256, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (384, 0, 128, 64), + placement_helper("cpu", 0), + ), + ( + (512, 0, 0, 64), + placement_helper("cpu", 0), + ), + ], + ] + + sharder = TestQuantEBCSharder( + sharding_type=ShardingType.ROW_WISE.value, + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in mi.tables], + ) + + module_plan = construct_module_sharding_plan( + non_sharded_model._module.sparse.ebc, + per_param_sharding={ + "table_0": row_wise( + ([256, 128, 64, 64], "cpu"), + ), + "table_1": row_wise(([128, 128, 128, 128], "cpu")), + "table_2": row_wise(([256, 128, 128, 0], "cpu")), + }, + # pyre-ignore + sharder=sharder, + local_size=local_size, + world_size=world_size, + ) + + plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan}) + + sharded_model = shard_qebc( + mi=mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=expected_shards, + plan=plan, + ) + inputs = [ + model_input_to_forward_args(inp.to(local_device)) + for inp in prep_inputs(mi, world_size, batch_size, long_indices=False) + ] + sharded_model.load_state_dict(non_sharded_model.state_dict()) + + # We need this first inference to make all lazy init in forward + sharded_output = sharded_model(*inputs[0]) + non_sharded_output = non_sharded_model(*inputs[0]) + assert_close(non_sharded_output, sharded_output) gm: torch.fx.GraphModule = symbolic_trace(sharded_model) gm_script = torch.jit.script(gm) - _ = gm_script(*inputs[0]) - # TODO (drqiangzhang): Add comparison between scripted and nonscripted model outputs + gm_script_output = gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output)