get per tbe cache_load_factor (pytorch#1461)
Summary:
Pull Request resolved: pytorch#1461

We always want to create fewer TBEs by grouping multiple tables together. The problem is that the grouped tables can have different cache_load_factor values.

Hence we compute a per-TBE cache_load_factor as the weighted average of the grouped tables' CLFs, weighted by each table's hash size (num_embeddings).
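
As a rough illustration (not torchrec code; the dict-based tables and the helper name below are made up for this example), the per-TBE value is just a hash-size-weighted mean:

from typing import Dict, List, Optional, Union

def weighted_avg_clf(tables: List[Dict[str, Union[int, float]]]) -> Optional[float]:
    # Weight each table's cache_load_factor by its hash size (num_embeddings).
    clf_sum = 0.0
    total_rows = 0
    for t in tables:
        if "cache_load_factor" in t:
            clf_sum += t["cache_load_factor"] * t["num_embeddings"]
            total_rows += t["num_embeddings"]
    # No caching tables in the group -> no per-TBE CLF (caller falls back to the default).
    return clf_sum / total_rows if total_rows else None

# Example: two UVM-caching tables grouped into one TBE.
tables = [
    {"num_embeddings": 1_000_000, "cache_load_factor": 0.2},
    {"num_embeddings": 9_000_000, "cache_load_factor": 0.05},
]
print(weighted_avg_clf(tables))  # (0.2 * 1e6 + 0.05 * 9e6) / 1e7 = 0.065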

Differential Revision: D49070469

fbshipit-source-id: ed448cf0ff6c0d779f806a7985380aebd7416cf3
henrylhtsang authored and facebook-github-bot committed Oct 27, 2023
1 parent f3ba8d9 commit a96157f
Showing 1 changed file with 79 additions and 12 deletions.
91 changes: 79 additions & 12 deletions torchrec/distributed/embedding_sharding.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import abc
import copy
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
@@ -123,6 +124,55 @@ def bucketize_kjt_before_all2all(
)


def _get_weighted_avg_cache_load_factor(
embedding_tables: List[ShardedEmbeddingTable],
) -> Optional[float]:
"""
Calculate the weighted average cache load factor of all tables. The cache
load factors are weighted by the hash size of each table.
"""
cache_load_factor_sum: float = 0.0
weight: int = 0

for table in embedding_tables:
if (
table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
and table.fused_params
and "cache_load_factor" in table.fused_params
):
cache_load_factor_sum += (
table.fused_params["cache_load_factor"] * table.num_embeddings
)
weight += table.num_embeddings

# if no fused_uvm_caching tables, return default cache load factor
if weight == 0:
return None

return cache_load_factor_sum / weight


def _get_grouping_fused_params(
fused_params: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""
    Only shallow copy the fused params we need for grouping tables into TBEs. In
particular, we do not copy cache_load_factor.
"""
grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)

if not grouping_fused_params:
return grouping_fused_params

non_grouping_params: List[str] = ["cache_load_factor"]

for param in non_grouping_params:
if param in grouping_fused_params:
del grouping_fused_params[param]

return grouping_fused_params


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -143,13 +193,15 @@ def _group_tables_per_rank(
) -> List[GroupedEmbeddingConfig]:
grouped_embedding_configs: List[GroupedEmbeddingConfig] = []

-        # add fused params:
+        # populate grouping fused params
fused_params_groups = []
for table in embedding_tables:
if table.fused_params is None:
table.fused_params = {}
-            if table.fused_params not in fused_params_groups:
-                fused_params_groups.append(table.fused_params)
+
+            grouping_fused_params = _get_grouping_fused_params(table.fused_params)
+            if grouping_fused_params not in fused_params_groups:
+                fused_params_groups.append(grouping_fused_params)

compute_kernels = [
EmbeddingComputeKernel.DENSE,
@@ -198,7 +250,10 @@ def _group_tables_per_rank(
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
-                            and table.fused_params == fused_params_group
+                            and fused_params_group
+                            == _get_grouping_fused_params(
+                                table.fused_params
+                            )
and (
emb_dim_bucketer.get_bucket(
table.embedding_dim, table.data_type
@@ -217,6 +272,21 @@
fused_params_group = {}

if grouped_tables:
cache_load_factor = (
_get_weighted_avg_cache_load_factor(
grouped_tables
)
)
per_tbe_fused_params = copy.copy(fused_params_group)
if cache_load_factor is not None:
per_tbe_fused_params[
"cache_load_factor"
] = cache_load_factor

# drop '_batch_key' not a native fused param
if "_batch_key" in per_tbe_fused_params:
del per_tbe_fused_params["_batch_key"]

grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
@@ -225,16 +295,13 @@
has_feature_processor=has_feature_processor,
compute_kernel=compute_kernel,
embedding_tables=grouped_tables,
-                                    fused_params={
-                                        k: v
-                                        for k, v in fused_params_group.items()
-                                        if k
-                                        not in [
-                                            "_batch_key"
-                                        ]  # drop '_batch_key' not a native fused param
-                                    },
+                                    fused_params=per_tbe_fused_params,
)
)

logging.info(
f"Per table cache_load_factor is {cache_load_factor}"
)
return grouped_embedding_configs

table_weightedness = [