From ef3d9329e3bccd40a9e87b52323b7471372dfc40 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Wed, 15 Nov 2023 20:24:03 -0800
Subject: [PATCH] get per tbe cache_load_factor (#1461)

Summary:
We always want to create fewer TBEs by grouping multiple tables together. The
problem is that the grouped tables can each have a different cache_load_factor.

To handle this, we calculate a per-TBE cache_load_factor as the weighted
average of the tables' CLFs, weighted by each table's hash size.

Differential Revision: D49070469
---
 torchrec/distributed/embedding_sharding.py    |  75 +++++++-
 .../tests/test_embedding_sharding.py          | 170 ++++++++++++++++++
 2 files changed, 240 insertions(+), 5 deletions(-)
 create mode 100644 torchrec/distributed/tests/test_embedding_sharding.py

diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 857e4cb68..3d3c756bd 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
+import copy
 import logging
 from dataclasses import dataclass, field
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
@@ -49,6 +50,8 @@
 
 torch.fx.wrap("len")
 
+CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
+
 
 # torch.Tensor.to can not be fx symbolic traced as it does not go through __torch_dispatch__ => fx.wrap it
 @torch.fx.wrap
@@ -123,6 +126,52 @@ def bucketize_kjt_before_all2all(
     )
 
 
+def _get_weighted_avg_cache_load_factor(
+    embedding_tables: List[ShardedEmbeddingTable],
+) -> Optional[float]:
+    """
+    Calculate the weighted average cache load factor of all tables. The cache
+    load factors are weighted by the hash size of each table.
+    """
+    cache_load_factor_sum: float = 0.0
+    weight: int = 0
+
+    for table in embedding_tables:
+        if (
+            table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+            and table.fused_params
+            and CACHE_LOAD_FACTOR_STR in table.fused_params
+        ):
+            cache_load_factor_sum += (
+                table.fused_params[CACHE_LOAD_FACTOR_STR] * table.num_embeddings
+            )
+            weight += table.num_embeddings
+
+    # if no fused_uvm_caching tables, return default cache load factor
+    if weight == 0:
+        return None
+
+    return cache_load_factor_sum / weight
+
+
+def _get_grouping_fused_params(
+    fused_params: Optional[Dict[str, Any]]
+) -> Optional[Dict[str, Any]]:
+    """
+    Only shallow copy the fused params we need for grouping tables into TBEs. In
+    particular, we do not copy cache_load_factor.
+    """
+    grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)
+
+    if not grouping_fused_params:
+        return grouping_fused_params
+
+    if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
+        del grouping_fused_params[CACHE_LOAD_FACTOR_STR]
+
+    return grouping_fused_params
+
+
 # group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
 def group_tables(
     tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -143,13 +192,15 @@ def _group_tables_per_rank(
     ) -> List[GroupedEmbeddingConfig]:
         grouped_embedding_configs: List[GroupedEmbeddingConfig] = []
 
-        # add fused params:
+        # populate grouping fused params
         fused_params_groups = []
         for table in embedding_tables:
             if table.fused_params is None:
                 table.fused_params = {}
-            if table.fused_params not in fused_params_groups:
-                fused_params_groups.append(table.fused_params)
+
+            grouping_fused_params = _get_grouping_fused_params(table.fused_params)
+            if grouping_fused_params not in fused_params_groups:
+                fused_params_groups.append(grouping_fused_params)
 
         compute_kernels = [
             EmbeddingComputeKernel.DENSE,
@@ -198,7 +249,10 @@ def _group_tables_per_rank(
                                     and table.has_feature_processor
                                     == has_feature_processor
                                     and compute_kernel_type == compute_kernel
-                                    and table.fused_params == fused_params_group
+                                    and fused_params_group
+                                    == _get_grouping_fused_params(
+                                        table.fused_params
+                                    )
                                     and (
                                         emb_dim_bucketer.get_bucket(
                                             table.embedding_dim, table.data_type
@@ -215,6 +269,17 @@ def _group_tables_per_rank(
                                 logging.info(
                                     f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
                                 )
+                                cache_load_factor = (
+                                    _get_weighted_avg_cache_load_factor(
+                                        grouped_tables
+                                    )
+                                )
+                                per_tbe_fused_params = copy.copy(fused_params_group)
+                                if cache_load_factor is not None:
+                                    per_tbe_fused_params[
+                                        CACHE_LOAD_FACTOR_STR
+                                    ] = cache_load_factor
+
                                 grouped_embedding_configs.append(
                                     GroupedEmbeddingConfig(
                                         data_type=data_type,
@@ -225,7 +290,7 @@ def _group_tables_per_rank(
                                         embedding_tables=grouped_tables,
                                         fused_params={
                                             k: v
-                                            for k, v in fused_params_group.items()
+                                            for k, v in per_tbe_fused_params.items()
                                             if k
                                             not in [
                                                 "_batch_key"
diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py
new file mode 100644
index 000000000..d047fcc1e
--- /dev/null
+++ b/torchrec/distributed/tests/test_embedding_sharding.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import random
+import unittest
+from typing import List
+from unittest.mock import MagicMock
+
+import hypothesis.strategies as st
+
+from hypothesis import given, settings
+
+from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
+
+from torchrec.distributed.embedding_sharding import (
+    _get_grouping_fused_params,
+    _get_weighted_avg_cache_load_factor,
+    group_tables,
+)
+from torchrec.distributed.embedding_types import (
+    GroupedEmbeddingConfig,
+    ShardedEmbeddingTable,
+)
+from torchrec.modules.embedding_configs import DataType, PoolingType
+
+
+class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
+    def test_get_avg_cache_load_factor_hbm(self) -> None:
+        cache_load_factors = [random.random() for _ in range(5)]
+        embedding_tables: List[ShardedEmbeddingTable] = [
+            ShardedEmbeddingTable(
+                num_embeddings=1000,
+                embedding_dim=MagicMock(),
+                fused_params={"cache_load_factor": cache_load_factor},
+            )
+            for cache_load_factor in cache_load_factors
+        ]
+
+        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
+            embedding_tables
+        )
+        self.assertIsNone(weighted_avg_cache_load_factor)
+
+    def test_get_avg_cache_load_factor(self) -> None:
+        cache_load_factors = [random.random() for _ in range(5)]
+        embedding_tables: List[ShardedEmbeddingTable] = [
+            ShardedEmbeddingTable(
+                num_embeddings=1000,
+                embedding_dim=MagicMock(),
+                compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
+                fused_params={"cache_load_factor": cache_load_factor},
+            )
+            for cache_load_factor in cache_load_factors
+        ]
+
+        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
+            embedding_tables
+        )
+        expected_avg = sum(cache_load_factors) / len(cache_load_factors)
+        self.assertIsNotNone(weighted_avg_cache_load_factor)
+        self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_avg)
+
+    def test_get_weighted_avg_cache_load_factor(self) -> None:
+        hash_sizes = [random.randint(100, 1000) for _ in range(5)]
+        cache_load_factors = [random.random() for _ in range(5)]
+        embedding_tables: List[ShardedEmbeddingTable] = [
+            ShardedEmbeddingTable(
+                num_embeddings=hash_size,
+                embedding_dim=MagicMock(),
+                compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
+                fused_params={"cache_load_factor": cache_load_factor},
+            )
+            for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes)
+        ]
+
+        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
+            embedding_tables
+        )
+        expected_weighted_avg = sum(
+            cache_load_factor * hash_size
+            for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes)
+        ) / sum(hash_sizes)
+
+        self.assertIsNotNone(weighted_avg_cache_load_factor)
+        self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_weighted_avg)
+
+
+class TestGetGroupingFusedParams(unittest.TestCase):
+    def test_get_grouping_fused_params(self) -> None:
+        fused_params_groups = [
+            None,
+            {},
+            {"stochastic_rounding": False},
+            {"stochastic_rounding": False, "cache_load_factor": 0.4},
+        ]
+        grouping_fused_params_groups = [
+            _get_grouping_fused_params(fused_params)
+            for fused_params in fused_params_groups
+        ]
+        expected_grouping_fused_params_groups = [
+            None,
+            {},
+            {"stochastic_rounding": False},
+            {"stochastic_rounding": False},
+        ]
+
+        self.assertEqual(
+            grouping_fused_params_groups, expected_grouping_fused_params_groups
+        )
+
+
+class TestPerTBECacheLoadFactor(unittest.TestCase):
+    # pyre-ignore[56]
+    @given(
+        data_type=st.sampled_from([DataType.FP16, DataType.FP32]),
+        has_feature_processor=st.sampled_from([False, True]),
+        embedding_dim=st.sampled_from(list(range(160, 320, 40))),
+        pooling_type=st.sampled_from(list(PoolingType)),
+    )
+    @settings(max_examples=10, deadline=10000)
+    def test_per_tbe_clf_weighted_average(
+        self,
+        data_type: DataType,
+        has_feature_processor: bool,
+        embedding_dim: int,
+        pooling_type: PoolingType,
+    ) -> None:
+        compute_kernels = [
+            EmbeddingComputeKernel.FUSED_UVM_CACHING,
+            EmbeddingComputeKernel.FUSED_UVM_CACHING,
+            EmbeddingComputeKernel.FUSED,
+            EmbeddingComputeKernel.FUSED_UVM,
+        ]
+        fused_params_groups = [
+            {"cache_load_factor": 0.5},
+            {"cache_load_factor": 0.3},
+            {"cache_load_factor": 0.9},  # hbm table, would have no effect
+            {"cache_load_factor": 0.4},  # uvm table, would have no effect
+        ]
+        tables = [
+            ShardedEmbeddingTable(
+                name=f"table_{i}",
+                data_type=data_type,
+                pooling=pooling_type,
+                has_feature_processor=has_feature_processor,
+                fused_params=fused_params_groups[i],
+                compute_kernel=compute_kernels[i],
+                embedding_dim=embedding_dim,
+                num_embeddings=10000 * (2 * i + 1),  # 10000 and 30000
+            )
+            for i in range(4)
+        ]
+
+        # since we don't have access to _group_tables_per_rank
+        tables_per_rank: List[List[ShardedEmbeddingTable]] = [tables]
+
+        # taking only the list for the first rank
+        table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0]
+
+        # assert that they are grouped together
+        self.assertEqual(len(table_groups), 1)
+
+        table_group = table_groups[0]
+        self.assertIsNotNone(table_group.fused_params)
+        self.assertEqual(table_group.fused_params.get("cache_load_factor"), 0.35)
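
Note (editorial illustration, not part of the patch): the 0.35 asserted at the end of
test_per_tbe_clf_weighted_average follows directly from the weighted-average rule in the
summary. Only the two FUSED_UVM_CACHING tables contribute, each weighted by its hash size;
the FUSED (HBM) and FUSED_UVM tables are ignored. The standalone sketch below re-implements
that rule with plain tuples standing in for ShardedEmbeddingTable; the helper name
clf_weighted_avg and the string kernel tags are illustrative assumptions, not torchrec APIs.

    from typing import List, Optional, Tuple

    def clf_weighted_avg(tables: List[Tuple[str, int, float]]) -> Optional[float]:
        """Weighted-average CLF over UVM-caching tables only; None if there are none."""
        clf_sum: float = 0.0
        weight: int = 0
        for kernel, hash_size, clf in tables:
            # stand-in check for EmbeddingComputeKernel.FUSED_UVM_CACHING
            if kernel == "fused_uvm_caching":
                clf_sum += clf * hash_size
                weight += hash_size
        return None if weight == 0 else clf_sum / weight

    # The four tables from the test: (kernel, num_embeddings, cache_load_factor).
    tables = [
        ("fused_uvm_caching", 10000, 0.5),
        ("fused_uvm_caching", 30000, 0.3),
        ("fused", 50000, 0.9),      # HBM table: ignored
        ("fused_uvm", 70000, 0.4),  # UVM table: ignored
    ]

    # (0.5 * 10000 + 0.3 * 30000) / (10000 + 30000) = 14000 / 40000 = 0.35
    result = clf_weighted_avg(tables)
    assert result is not None and abs(result - 0.35) < 1e-9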