diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 3d3c756bd..60c6d51f3 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -40,11 +40,7 @@
     ShardMetadata,
 )
 from torchrec.fx.utils import assert_fx_safe
-from torchrec.modules.embedding_configs import (
-    DataType,
-    EmbeddingTableConfig,
-    PoolingType,
-)
+from torchrec.modules.embedding_configs import EmbeddingTableConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 
@@ -172,6 +168,26 @@ def _get_grouping_fused_params(
     return grouping_fused_params
 
 
+def _get_compute_kernel_type(
+    compute_kernel: EmbeddingComputeKernel,
+) -> EmbeddingComputeKernel:
+    """
+    Return the compute kernel type for the given compute kernel.
+    """
+    compute_kernel_type = compute_kernel
+    if compute_kernel_type in [
+        EmbeddingComputeKernel.FUSED_UVM,
+        EmbeddingComputeKernel.FUSED_UVM_CACHING,
+    ]:
+        compute_kernel_type = EmbeddingComputeKernel.FUSED
+    elif compute_kernel_type in [
+        EmbeddingComputeKernel.QUANT_UVM,
+        EmbeddingComputeKernel.QUANT_UVM_CACHING,
+    ]:
+        compute_kernel_type = EmbeddingComputeKernel.QUANT
+    return compute_kernel_type
+
+
 # group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
 def group_tables(
     tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -192,22 +208,6 @@ def _group_tables_per_rank(
     ) -> List[GroupedEmbeddingConfig]:
         grouped_embedding_configs: List[GroupedEmbeddingConfig] = []
 
-        # populate grouping fused params
-        fused_params_groups = []
-        for table in embedding_tables:
-            if table.fused_params is None:
-                table.fused_params = {}
-
-            grouping_fused_params = _get_grouping_fused_params(table.fused_params)
-            if grouping_fused_params not in fused_params_groups:
-                fused_params_groups.append(grouping_fused_params)
-
-        compute_kernels = [
-            EmbeddingComputeKernel.DENSE,
-            EmbeddingComputeKernel.FUSED,
-            EmbeddingComputeKernel.QUANT,
-        ]
-
         emb_dim_bucketer_policy = (
             EmbDimBucketerPolicy.ALL_BUCKETS
             if should_do_dim_bucketing(embedding_tables)
@@ -216,88 +216,83 @@ def _group_tables_per_rank(
         emb_dim_bucketer = EmbDimBucketer(embedding_tables, emb_dim_bucketer_policy)
         logging.info(f"bucket count {emb_dim_bucketer.bucket_count()}")
 
-        for data_type in DataType:
-            for pooling in PoolingType:
-                # remove this when finishing migration
-                for has_feature_processor in [False, True]:
-                    for fused_params_group in fused_params_groups:
-                        for compute_kernel in compute_kernels:
-                            for dim_bucket in range(emb_dim_bucketer.bucket_count()):
-                                grouped_tables: List[ShardedEmbeddingTable] = []
-                                is_weighted = False
-                                for table in embedding_tables:
-                                    compute_kernel_type = table.compute_kernel
-                                    is_weighted = table.is_weighted
-                                    if table.compute_kernel in [
-                                        EmbeddingComputeKernel.FUSED_UVM,
-                                        EmbeddingComputeKernel.FUSED_UVM_CACHING,
-                                    ]:
-                                        compute_kernel_type = (
-                                            EmbeddingComputeKernel.FUSED
-                                        )
-                                    elif table.compute_kernel in [
-                                        EmbeddingComputeKernel.QUANT_UVM,
-                                        EmbeddingComputeKernel.QUANT_UVM_CACHING,
-                                    ]:
-                                        compute_kernel_type = (
-                                            EmbeddingComputeKernel.QUANT
-                                        )
-
-                                    if (
-                                        table.data_type == data_type
-                                        and table.pooling.value == pooling.value
-                                        and table.has_feature_processor
-                                        == has_feature_processor
-                                        and compute_kernel_type == compute_kernel
-                                        and fused_params_group
-                                        == _get_grouping_fused_params(
-                                            table.fused_params
-                                        )
-                                        and (
-                                            emb_dim_bucketer.get_bucket(
-                                                table.embedding_dim, table.data_type
-                                            )
-                                            == dim_bucket
-                                        )
-                                    ):
-                                        grouped_tables.append(table)
-
-                                if fused_params_group is None:
-                                    fused_params_group = {}
-
-                                if grouped_tables:
-                                    logging.info(
-                                        f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
-                                    )
-                                    cache_load_factor = (
-                                        _get_weighted_avg_cache_load_factor(
-                                            grouped_tables
-                                        )
-                                    )
-                                    per_tbe_fused_params = copy.copy(fused_params_group)
-                                    if cache_load_factor is not None:
-                                        per_tbe_fused_params[
-                                            CACHE_LOAD_FACTOR_STR
-                                        ] = cache_load_factor
-
-                                    grouped_embedding_configs.append(
-                                        GroupedEmbeddingConfig(
-                                            data_type=data_type,
-                                            pooling=pooling,
-                                            is_weighted=is_weighted,
-                                            has_feature_processor=has_feature_processor,
-                                            compute_kernel=compute_kernel,
-                                            embedding_tables=grouped_tables,
-                                            fused_params={
-                                                k: v
-                                                for k, v in per_tbe_fused_params.items()
-                                                if k
-                                                not in [
-                                                    "_batch_key"
-                                                ]  # drop '_batch_key' not a native fused param
-                                            },
-                                        )
-                                    )
+        # populate grouping keys
+        grouping_keys = []
+        for table in embedding_tables:
+            if table.fused_params is None:
+                table.fused_params = {}
+
+            grouping_key = (
+                table.data_type,
+                table.pooling,
+                table.has_feature_processor,
+                _get_grouping_fused_params(table.fused_params),
+                _get_compute_kernel_type(table.compute_kernel),
+                emb_dim_bucketer.get_bucket(table.embedding_dim, table.data_type),
+            )
+            if grouping_key not in grouping_keys:
+                grouping_keys.append(grouping_key)
+
+        for grouping_key in grouping_keys:
+            (
+                data_type,
+                pooling,
+                has_feature_processor,
+                fused_params_group,
+                compute_kernel_type,
+                dim_bucket,
+            ) = grouping_key
+            grouped_tables: List[ShardedEmbeddingTable] = []
+            is_weighted = False
+            for table in embedding_tables:
+                is_weighted = table.is_weighted
+                if (
+                    table.data_type == data_type
+                    and table.pooling.value == pooling.value
+                    and table.has_feature_processor == has_feature_processor
+                    and fused_params_group
+                    == _get_grouping_fused_params(table.fused_params)
+                    and compute_kernel_type
+                    == _get_compute_kernel_type(table.compute_kernel)
+                    and (
+                        emb_dim_bucketer.get_bucket(
+                            table.embedding_dim, table.data_type
+                        )
+                        == dim_bucket
+                    )
+                ):
+                    grouped_tables.append(table)
+
+            if fused_params_group is None:
+                fused_params_group = {}
+
+            if grouped_tables:
+                logging.info(
+                    f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
+                )
+                cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables)
+                per_tbe_fused_params = copy.copy(fused_params_group)
+                if cache_load_factor is not None:
+                    per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor
+
+                grouped_embedding_configs.append(
+                    GroupedEmbeddingConfig(
+                        data_type=data_type,
+                        pooling=pooling,
+                        is_weighted=is_weighted,
+                        has_feature_processor=has_feature_processor,
+                        compute_kernel=compute_kernel_type,
+                        embedding_tables=grouped_tables,
+                        fused_params={
+                            k: v
+                            for k, v in per_tbe_fused_params.items()
+                            if k
+                            not in [
+                                "_batch_key"
+                            ]  # drop '_batch_key' not a native fused param
+                        },
+                    )
+                )
         return grouped_embedding_configs
 
     table_weightedness = [
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index aeb726b65..2ea4241a4 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -199,6 +199,12 @@ def dim_sum(self) -> int:
             dim_sum += table.num_features() * table.local_cols
         return dim_sum
 
+    def table_names(self) -> List[str]:
+        table_names = []
+        for table in self.embedding_tables:
+            table_names.append(table.name)
+        return table_names
+
     def feature_names(self) -> List[str]:
         feature_names = []
         for table in self.embedding_tables:
diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py
index d047fcc1e..a95995e6d 100644
--- a/torchrec/distributed/tests/test_embedding_sharding.py
+++ b/torchrec/distributed/tests/test_embedding_sharding.py
@@ -8,7 +8,7 @@
 
 import random
 import unittest
-from typing import List
+from typing import Any, Dict, List
 from unittest.mock import MagicMock
 
 import hypothesis.strategies as st
@@ -18,10 +18,12 @@
 
 from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
 from torchrec.distributed.embedding_sharding import (
+    _get_compute_kernel_type,
     _get_grouping_fused_params,
     _get_weighted_avg_cache_load_factor,
     group_tables,
 )
+
 from torchrec.distributed.embedding_types import (
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
@@ -168,3 +170,222 @@ def test_per_tbe_clf_weighted_average(
         table_group = table_groups[0]
         self.assertIsNotNone(table_group.fused_params)
         self.assertEqual(table_group.fused_params.get("cache_load_factor"), 0.35)
+
+
+def _get_table_names_by_groups(
+    embedding_tables: List[ShardedEmbeddingTable],
+) -> List[List[str]]:
+    # since we don't have access to _group_tables_per_rank
+    tables_per_rank: List[List[ShardedEmbeddingTable]] = [embedding_tables]
+
+    # taking only the list for the first rank
+    table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0]
+    return [table_group.table_names() for table_group in table_groups]
+
+
+class TestGroupTablesPerRank(unittest.TestCase):
+    # pyre-ignore[56]
+    @given(
+        data_type=st.sampled_from([DataType.FP16, DataType.FP32]),
+        has_feature_processor=st.sampled_from([False, True]),
+        fused_params_group=st.sampled_from(
+            [
+                {
+                    "cache_load_factor": 0.5,
+                    "prefetch_pipeline": False,
+                },
+                {
+                    "cache_load_factor": 0.3,
+                    "prefetch_pipeline": True,
+                },
+            ]
+        ),
+        embedding_dim=st.sampled_from(list(range(160, 320, 40))),
+        pooling_type=st.sampled_from(list(PoolingType)),
+        compute_kernel=st.sampled_from(list(EmbeddingComputeKernel)),
+    )
+    @settings(max_examples=10, deadline=10000)
+    def test_should_group_together(
+        self,
+        data_type: DataType,
+        has_feature_processor: bool,
+        fused_params_group: Dict[str, Any],
+        embedding_dim: int,
+        pooling_type: PoolingType,
+        compute_kernel: EmbeddingComputeKernel,
+    ) -> None:
+        tables = [
+            ShardedEmbeddingTable(
+                name=f"table_{i}",
+                data_type=data_type,
+                pooling=pooling_type,
+                has_feature_processor=has_feature_processor,
+                fused_params=fused_params_group,
+                compute_kernel=compute_kernel,
+                embedding_dim=embedding_dim,
+                num_embeddings=10000,
+            )
+            for i in range(2)
+        ]
+
+        expected_table_names_by_groups = [["table_0", "table_1"]]
+        self.assertEqual(
+            _get_table_names_by_groups(tables),
+            expected_table_names_by_groups,
+        )
+
+    # pyre-ignore[56]
+    @given(
+        data_type=st.sampled_from([DataType.FP16, DataType.FP32]),
+        has_feature_processor=st.sampled_from([False, True]),
+        embedding_dim=st.sampled_from(list(range(160, 320, 40))),
+        pooling_type=st.sampled_from(list(PoolingType)),
+        compute_kernel=st.sampled_from(list(EmbeddingComputeKernel)),
+    )
+    @settings(max_examples=10, deadline=10000)
+    def test_should_group_together_with_prefetch(
+        self,
+        data_type: DataType,
+        has_feature_processor: bool,
+        embedding_dim: int,
+        pooling_type: PoolingType,
+        compute_kernel: EmbeddingComputeKernel,
+    ) -> None:
+        fused_params_groups = [
+            {
+                "cache_load_factor": 0.3,
+                "prefetch_pipeline": True,
+            },
+            {
+                "cache_load_factor": 0.5,
+                "prefetch_pipeline": True,
+            },
+        ]
+        tables = [
+            ShardedEmbeddingTable(
+                name=f"table_{i}",
+                data_type=data_type,
+                pooling=pooling_type,
+                has_feature_processor=has_feature_processor,
+                fused_params=fused_params_groups[i],
+                compute_kernel=compute_kernel,
+                embedding_dim=embedding_dim,
+                num_embeddings=10000,
+            )
+            for i in range(2)
+        ]
+
+        expected_table_names_by_groups = [["table_0", "table_1"]]
+        self.assertEqual(
+            _get_table_names_by_groups(tables),
+            expected_table_names_by_groups,
+        )
+
+    # pyre-ignore[56]
+    @given(
+        data_types=st.lists(
+            st.sampled_from([DataType.FP16, DataType.FP32]),
+            min_size=2,
+            max_size=2,
+            unique=True,
+        ),
+        has_feature_processors=st.lists(
+            st.sampled_from([False, True]), min_size=2, max_size=2, unique=True
+        ),
+        fused_params_group=st.sampled_from(
+            [
+                {
+                    "cache_load_factor": 0.5,
+                    "prefetch_pipeline": True,
+                },
+                {
+                    "cache_load_factor": 0.3,
+                    "prefetch_pipeline": True,
+                },
+            ],
+        ),
+        embedding_dims=st.lists(
+            st.sampled_from(list(range(160, 320, 40))),
+            min_size=2,
+            max_size=2,
+            unique=True,
+        ),
+        pooling_types=st.lists(
+            st.sampled_from(list(PoolingType)), min_size=2, max_size=2, unique=True
+        ),
+        compute_kernels=st.lists(
+            st.sampled_from(list(EmbeddingComputeKernel)),
+            min_size=2,
+            max_size=2,
+            unique=True,
+        ),
+        distinct_key=st.sampled_from(
+            [
+                "data_type",
+                "has_feature_processor",
+                "embedding_dim",
+                "pooling_type",
+                "compute_kernel",
+            ]
+        ),
+    )
+    @settings(max_examples=10, deadline=10000)
+    def test_should_not_group_together(
+        self,
+        data_types: List[DataType],
+        has_feature_processors: List[bool],
+        fused_params_group: Dict[str, Any],
+        embedding_dims: List[int],
+        pooling_types: List[PoolingType],
+        compute_kernels: List[EmbeddingComputeKernel],
+        distinct_key: str,
+    ) -> None:
+        tables = [
+            ShardedEmbeddingTable(
+                name=f"table_{i}",
+                data_type=data_types[i]
+                if distinct_key == "data_type"
+                else data_types[0],
+                pooling=pooling_types[i]
+                if distinct_key == "pooling_type"
+                else pooling_types[0],
+                has_feature_processor=has_feature_processors[i]
+                if distinct_key == "has_feature_processor"
+                else has_feature_processors[0],
+                fused_params=fused_params_group,  # can't hash dicts
+                compute_kernel=compute_kernels[i]
+                if distinct_key == "compute_kernel"
+                else compute_kernels[0],
+                embedding_dim=embedding_dims[i]
+                if distinct_key == "embedding_dim"
+                else embedding_dims[0],
+                num_embeddings=10000,
+            )
+            for i in range(2)
+        ]
+
+        if distinct_key == "compute_kernel" and _get_compute_kernel_type(
+            compute_kernels[0]
+        ) == _get_compute_kernel_type(compute_kernels[1]):
+            self.assertEqual(
+                _get_table_names_by_groups(tables),
+                [["table_0", "table_1"]],
+            )
+            return
+
+        # emb dim bucketizer is only in use when the compute kernel is uvm caching
+        # and prefetch pipeline is True
+        if (
+            distinct_key == "embedding_dim"
+            and compute_kernels[0] != EmbeddingComputeKernel.FUSED_UVM_CACHING
+        ):
+            self.assertEqual(
+                _get_table_names_by_groups(tables),
+                [["table_0", "table_1"]],
+            )
+            return
+
+        self.assertEqual(
+            sorted(_get_table_names_by_groups(tables)),
+            [["table_0"], ["table_1"]],
+        )
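
For readers skimming the patch, below is a minimal standalone sketch of the grouping-key idea it introduces; the Table dataclass and helper names are simplified stand-ins, not torchrec's actual types. Instead of looping over the full cross product of DataType, PoolingType, feature-processor flag, fused-params group, compute kernel, and dim bucket and filtering tables inside the innermost loop, the refactor derives one tuple key per table and only visits keys that actually occur.

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Any, Dict, List, Tuple


    # Simplified stand-in for ShardedEmbeddingTable; illustrative only.
    @dataclass
    class Table:
        name: str
        data_type: str
        pooling: str
        has_feature_processor: bool
        compute_kernel: str
        dim_bucket: int


    def grouping_key(table: Table) -> Tuple[Any, ...]:
        # Tables that share this key land in the same group, mirroring the
        # (data_type, pooling, has_feature_processor, fused_params, kernel, bucket)
        # tuple built by the patch.
        return (
            table.data_type,
            table.pooling,
            table.has_feature_processor,
            table.compute_kernel,
            table.dim_bucket,
        )


    def group_by_key(tables: List[Table]) -> List[List[str]]:
        # One pass over the tables instead of one pass per key combination.
        groups: Dict[Tuple[Any, ...], List[str]] = defaultdict(list)
        for table in tables:
            groups[grouping_key(table)].append(table.name)
        return list(groups.values())


    if __name__ == "__main__":
        tables = [
            Table("table_0", "FP16", "SUM", False, "FUSED", 0),
            Table("table_1", "FP16", "SUM", False, "FUSED", 0),
            Table("table_2", "FP32", "SUM", False, "FUSED", 0),
        ]
        # table_0 and table_1 share a key; table_2 differs in data_type.
        print(group_by_key(tables))  # [['table_0', 'table_1'], ['table_2']]

Note that the patch itself rescans embedding_tables for each collected key rather than using a dict, so the per-table membership checks (via _get_grouping_fused_params and _get_compute_kernel_type) stay exactly as before; the sketch above is only meant to show why a precomputed key removes the cross-product loops.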