From a96157fd3bb7e2829733e12f40338313684cff60 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Fri, 27 Oct 2023 10:44:15 -0700
Subject: [PATCH] get per tbe cache_load_factor (#1461)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1461

We always want to create fewer TBEs by grouping multiple tables together. The problem is that the grouped tables can have different cache_load_factors. Hence we compute a per-TBE cache_load_factor as the weighted average of the tables' CLFs, weighted by each table's hash size.

Differential Revision: D49070469

fbshipit-source-id: ed448cf0ff6c0d779f806a7985380aebd7416cf3
---
 torchrec/distributed/embedding_sharding.py | 91 +++++++++++++++++++---
 1 file changed, 79 insertions(+), 12 deletions(-)

diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 9fe2b4bb2..4e374aabb 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
+import copy
 import logging
 from dataclasses import dataclass, field
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
@@ -123,6 +124,55 @@ def bucketize_kjt_before_all2all(
     )
 
 
+def _get_weighted_avg_cache_load_factor(
+    embedding_tables: List[ShardedEmbeddingTable],
+) -> Optional[float]:
+    """
+    Calculate the weighted average cache load factor of all tables. The cache
+    load factors are weighted by the hash size of each table.
+    """
+    cache_load_factor_sum: float = 0.0
+    weight: int = 0
+
+    for table in embedding_tables:
+        if (
+            table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+            and table.fused_params
+            and "cache_load_factor" in table.fused_params
+        ):
+            cache_load_factor_sum += (
+                table.fused_params["cache_load_factor"] * table.num_embeddings
+            )
+            weight += table.num_embeddings
+
+    # if there are no fused_uvm_caching tables, return None so callers
+    # fall back to the default cache load factor
+    if weight == 0:
+        return None
+
+    return cache_load_factor_sum / weight
+
+
+def _get_grouping_fused_params(
+    fused_params: Optional[Dict[str, Any]]
+) -> Optional[Dict[str, Any]]:
+    """
+    Only shallow copy the fused params we need for grouping tables into TBEs.
+    In particular, we do not copy cache_load_factor.
+    """
+    grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)
+
+    if not grouping_fused_params:
+        return grouping_fused_params
+
+    non_grouping_params: List[str] = ["cache_load_factor"]
+
+    for param in non_grouping_params:
+        if param in grouping_fused_params:
+            del grouping_fused_params[param]
+
+    return grouping_fused_params
+
+
 # group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
 def group_tables(
     tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -143,13 +193,15 @@ def _group_tables_per_rank(
     ) -> List[GroupedEmbeddingConfig]:
         grouped_embedding_configs: List[GroupedEmbeddingConfig] = []
 
-        # add fused params:
+        # populate grouping fused params
         fused_params_groups = []
         for table in embedding_tables:
             if table.fused_params is None:
                 table.fused_params = {}
-            if table.fused_params not in fused_params_groups:
-                fused_params_groups.append(table.fused_params)
+
+            grouping_fused_params = _get_grouping_fused_params(table.fused_params)
+            if grouping_fused_params not in fused_params_groups:
+                fused_params_groups.append(grouping_fused_params)
 
         compute_kernels = [
             EmbeddingComputeKernel.DENSE,
@@ -198,7 +250,10 @@ def _group_tables_per_rank(
                             and table.has_feature_processor
                             == has_feature_processor
                             and compute_kernel_type == compute_kernel
-                            and table.fused_params == fused_params_group
+                            and fused_params_group
+                            == _get_grouping_fused_params(
+                                table.fused_params
+                            )
                             and (
                                 emb_dim_bucketer.get_bucket(
                                     table.embedding_dim, table.data_type
@@ -217,6 +272,21 @@ def _group_tables_per_rank(
                             fused_params_group = {}
 
                         if grouped_tables:
+                            cache_load_factor = (
+                                _get_weighted_avg_cache_load_factor(
+                                    grouped_tables
+                                )
+                            )
+                            per_tbe_fused_params = copy.copy(fused_params_group)
+                            if cache_load_factor is not None:
+                                per_tbe_fused_params[
+                                    "cache_load_factor"
+                                ] = cache_load_factor
+
+                            # drop '_batch_key' as it is not a native fused param
+                            if "_batch_key" in per_tbe_fused_params:
+                                del per_tbe_fused_params["_batch_key"]
+
                             grouped_embedding_configs.append(
                                 GroupedEmbeddingConfig(
                                     data_type=data_type,
@@ -225,16 +295,13 @@ def _group_tables_per_rank(
                                     has_feature_processor=has_feature_processor,
                                     compute_kernel=compute_kernel,
                                     embedding_tables=grouped_tables,
-                                    fused_params={
-                                        k: v
-                                        for k, v in fused_params_group.items()
-                                        if k
-                                        not in [
-                                            "_batch_key"
-                                        ]  # drop '_batch_key' as it is not a native fused param
-                                    },
+                                    fused_params=per_tbe_fused_params,
                                 )
                             )
+
+                            logging.info(
+                                f"Per-TBE cache_load_factor is {cache_load_factor}"
+                            )
         return grouped_embedding_configs
 
     table_weightedness = [
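
To illustrate the weighted-average math used by `_get_weighted_avg_cache_load_factor`, here is a minimal standalone sketch; `Table` is a hypothetical stand-in for `ShardedEmbeddingTable` that models only the fields the helper reads, and is not part of the patch:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Table:
        num_embeddings: int   # hash size, used as the weight
        clf: Optional[float]  # None if the table is not cached

    def weighted_avg_clf(tables: List[Table]) -> Optional[float]:
        clf_sum = 0.0
        weight = 0
        for t in tables:
            if t.clf is not None:
                clf_sum += t.clf * t.num_embeddings
                weight += t.num_embeddings
        # no cached tables: let the caller fall back to the default CLF
        return None if weight == 0 else clf_sum / weight

    # A 10M-row table at CLF 0.2 grouped with a 1M-row table at CLF 0.5:
    # (0.2 * 10e6 + 0.5 * 1e6) / 11e6 ~= 0.227, dominated by the larger table.
    print(weighted_avg_clf([Table(10_000_000, 0.2), Table(1_000_000, 0.5)]))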
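The grouping-key change can be sketched the same way: tables are bucketed by their fused params with `cache_load_factor` stripped, so two tables that differ only in CLF now share one TBE. `strip_clf` below is a hypothetical stand-in for `_get_grouping_fused_params`:

    import copy
    from typing import Any, Dict, Optional

    def strip_clf(fused_params: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        # shallow copy so the caller's dict keeps its cache_load_factor
        grouping = copy.copy(fused_params)
        if grouping and "cache_load_factor" in grouping:
            del grouping["cache_load_factor"]
        return grouping

    t1 = {"cache_load_factor": 0.2, "learning_rate": 0.01}
    t2 = {"cache_load_factor": 0.5, "learning_rate": 0.01}

    # Before this patch the raw dicts were compared, so t1 and t2 landed in
    # different groups; with CLF stripped they compare equal and share a TBE,
    # which then gets the weighted-average CLF from the helper above.
    assert t1 != t2
    assert strip_clf(t1) == strip_clf(t2)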