get per tbe cache_load_factor (pytorch#1461)
Summary:

We always want to create fewer TBEs by grouping multiple tables together. The problem is that the grouped tables can each have a different cache_load_factor.

Hence we compute a per-TBE cache_load_factor as the weighted average of the tables' CLFs, weighted by each table's hash size.
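
For intuition, a minimal standalone sketch of that weighted average (illustrative only, not the torchrec implementation; the helper name and the example hash sizes/CLFs are made up to mirror the unit test below):

from typing import List, Optional, Tuple

def weighted_avg_clf(tables: List[Tuple[int, Optional[float]]]) -> Optional[float]:
    # tables: (hash_size, cache_load_factor), with None CLF for non-caching tables.
    numerator = sum(size * clf for size, clf in tables if clf is not None)
    denominator = sum(size for size, clf in tables if clf is not None)
    return numerator / denominator if denominator else None

# CLFs 0.5 and 0.3 with hash sizes 10000 and 30000 give
# (0.5 * 10000 + 0.3 * 30000) / (10000 + 30000) = 0.35;
# the non-caching table is ignored.
print(weighted_avg_clf([(10000, 0.5), (30000, 0.3), (50000, None)]))  # 0.35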

Differential Revision: D49070469
hlhtsang authored and facebook-github-bot committed Nov 16, 2023
1 parent cf32b32 commit ef3d932
Showing 2 changed files with 240 additions and 5 deletions.
75 changes: 70 additions & 5 deletions torchrec/distributed/embedding_sharding.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import abc
import copy
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
@@ -49,6 +50,8 @@

torch.fx.wrap("len")

CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"


# torch.Tensor.to can not be fx symbolic traced as it does not go through __torch_dispatch__ => fx.wrap it
@torch.fx.wrap
@@ -123,6 +126,52 @@ def bucketize_kjt_before_all2all(
)


def _get_weighted_avg_cache_load_factor(
    embedding_tables: List[ShardedEmbeddingTable],
) -> Optional[float]:
    """
    Calculate the weighted average cache load factor of all tables. The cache
    load factors are weighted by the hash size of each table.
    """
    cache_load_factor_sum: float = 0.0
    weight: int = 0

    for table in embedding_tables:
        if (
            table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
            and table.fused_params
            and CACHE_LOAD_FACTOR_STR in table.fused_params
        ):
            cache_load_factor_sum += (
                table.fused_params[CACHE_LOAD_FACTOR_STR] * table.num_embeddings
            )
            weight += table.num_embeddings

    # if no fused_uvm_caching tables, return default cache load factor
    if weight == 0:
        return None

    return cache_load_factor_sum / weight


def _get_grouping_fused_params(
    fused_params: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """
    Only shallow copy the fused params we need for grouping tables into TBEs. In
    particular, we do not copy cache_load_factor.
    """
    grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)

    if not grouping_fused_params:
        return grouping_fused_params

    if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
        del grouping_fused_params[CACHE_LOAD_FACTOR_STR]

    return grouping_fused_params
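
# Illustration (editor's note, not part of this commit's diff): because
# cache_load_factor is dropped from the grouping key, tables that differ only
# in CLF compare equal for grouping purposes, e.g.
#   _get_grouping_fused_params({"stochastic_rounding": False, "cache_load_factor": 0.5})
#   == _get_grouping_fused_params({"stochastic_rounding": False, "cache_load_factor": 0.3})
#   == {"stochastic_rounding": False}
# so such tables can share one TBE.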


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -143,13 +192,15 @@ def _group_tables_per_rank(
) -> List[GroupedEmbeddingConfig]:
grouped_embedding_configs: List[GroupedEmbeddingConfig] = []

# add fused params:
# populate grouping fused params
fused_params_groups = []
for table in embedding_tables:
if table.fused_params is None:
table.fused_params = {}
if table.fused_params not in fused_params_groups:
fused_params_groups.append(table.fused_params)

grouping_fused_params = _get_grouping_fused_params(table.fused_params)
if grouping_fused_params not in fused_params_groups:
fused_params_groups.append(grouping_fused_params)

compute_kernels = [
EmbeddingComputeKernel.DENSE,
@@ -198,7 +249,10 @@ def _group_tables_per_rank(
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
and table.fused_params == fused_params_group
and fused_params_group
== _get_grouping_fused_params(
table.fused_params
)
and (
emb_dim_bucketer.get_bucket(
table.embedding_dim, table.data_type
@@ -215,6 +269,17 @@
logging.info(
f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
)
cache_load_factor = (
_get_weighted_avg_cache_load_factor(
grouped_tables
)
)
per_tbe_fused_params = copy.copy(fused_params_group)
if cache_load_factor is not None:
per_tbe_fused_params[
CACHE_LOAD_FACTOR_STR
] = cache_load_factor

grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
@@ -225,7 +290,7 @@
embedding_tables=grouped_tables,
fused_params={
k: v
for k, v in fused_params_group.items()
for k, v in per_tbe_fused_params.items()
if k
not in [
"_batch_key"
170 changes: 170 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import random
import unittest
from typing import List
from unittest.mock import MagicMock

import hypothesis.strategies as st

from hypothesis import given, settings

from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel

from torchrec.distributed.embedding_sharding import (
_get_grouping_fused_params,
_get_weighted_avg_cache_load_factor,
group_tables,
)
from torchrec.distributed.embedding_types import (
GroupedEmbeddingConfig,
ShardedEmbeddingTable,
)
from torchrec.modules.embedding_configs import DataType, PoolingType


class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
    def test_get_avg_cache_load_factor_hbm(self) -> None:
        cache_load_factors = [random.random() for _ in range(5)]
        embedding_tables: List[ShardedEmbeddingTable] = [
            ShardedEmbeddingTable(
                num_embeddings=1000,
                embedding_dim=MagicMock(),
                fused_params={"cache_load_factor": cache_load_factor},
            )
            for cache_load_factor in cache_load_factors
        ]

        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
            embedding_tables
        )
        self.assertIsNone(weighted_avg_cache_load_factor)

    def test_get_avg_cache_load_factor(self) -> None:
        cache_load_factors = [random.random() for _ in range(5)]
        embedding_tables: List[ShardedEmbeddingTable] = [
            ShardedEmbeddingTable(
                num_embeddings=1000,
                embedding_dim=MagicMock(),
                compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
                fused_params={"cache_load_factor": cache_load_factor},
            )
            for cache_load_factor in cache_load_factors
        ]

        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
            embedding_tables
        )
        expected_avg = sum(cache_load_factors) / len(cache_load_factors)
        self.assertIsNotNone(weighted_avg_cache_load_factor)
        self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_avg)

    def test_get_weighted_avg_cache_load_factor(self) -> None:
        hash_sizes = [random.randint(100, 1000) for _ in range(5)]
        cache_load_factors = [random.random() for _ in range(5)]
        embedding_tables: List[ShardedEmbeddingTable] = [
            ShardedEmbeddingTable(
                num_embeddings=hash_size,
                embedding_dim=MagicMock(),
                compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
                fused_params={"cache_load_factor": cache_load_factor},
            )
            for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes)
        ]

        weighted_avg_cache_load_factor = _get_weighted_avg_cache_load_factor(
            embedding_tables
        )
        expected_weighted_avg = sum(
            cache_load_factor * hash_size
            for cache_load_factor, hash_size in zip(cache_load_factors, hash_sizes)
        ) / sum(hash_sizes)

        self.assertIsNotNone(weighted_avg_cache_load_factor)
        self.assertAlmostEqual(weighted_avg_cache_load_factor, expected_weighted_avg)


class TestGetGroupingFusedParams(unittest.TestCase):
    def test_get_grouping_fused_params(self) -> None:
        fused_params_groups = [
            None,
            {},
            {"stochastic_rounding": False},
            {"stochastic_rounding": False, "cache_load_factor": 0.4},
        ]
        grouping_fused_params_groups = [
            _get_grouping_fused_params(fused_params)
            for fused_params in fused_params_groups
        ]
        expected_grouping_fused_params_groups = [
            None,
            {},
            {"stochastic_rounding": False},
            {"stochastic_rounding": False},
        ]

        self.assertEqual(
            grouping_fused_params_groups, expected_grouping_fused_params_groups
        )


class TestPerTBECacheLoadFactor(unittest.TestCase):
    # pyre-ignore[56]
    @given(
        data_type=st.sampled_from([DataType.FP16, DataType.FP32]),
        has_feature_processor=st.sampled_from([False, True]),
        embedding_dim=st.sampled_from(list(range(160, 320, 40))),
        pooling_type=st.sampled_from(list(PoolingType)),
    )
    @settings(max_examples=10, deadline=10000)
    def test_per_tbe_clf_weighted_average(
        self,
        data_type: DataType,
        has_feature_processor: bool,
        embedding_dim: int,
        pooling_type: PoolingType,
    ) -> None:
        compute_kernels = [
            EmbeddingComputeKernel.FUSED_UVM_CACHING,
            EmbeddingComputeKernel.FUSED_UVM_CACHING,
            EmbeddingComputeKernel.FUSED,
            EmbeddingComputeKernel.FUSED_UVM,
        ]
        fused_params_groups = [
            {"cache_load_factor": 0.5},
            {"cache_load_factor": 0.3},
            {"cache_load_factor": 0.9},  # hbm table, would have no effect
            {"cache_load_factor": 0.4},  # uvm table, would have no effect
        ]
        tables = [
            ShardedEmbeddingTable(
                name=f"table_{i}",
                data_type=data_type,
                pooling=pooling_type,
                has_feature_processor=has_feature_processor,
                fused_params=fused_params_groups[i],
                compute_kernel=compute_kernels[i],
                embedding_dim=embedding_dim,
                num_embeddings=10000 * (2 * i + 1),  # 10000 and 30000
            )
            for i in range(4)
        ]

        # since we don't have access to _group_tables_per_rank
        tables_per_rank: List[List[ShardedEmbeddingTable]] = [tables]

        # taking only the list for the first rank
        table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0]

        # assert that they are grouped together
        self.assertEqual(len(table_groups), 1)

        table_group = table_groups[0]
        self.assertIsNotNone(table_group.fused_params)
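        # expected per-TBE CLF: the two FUSED_UVM_CACHING tables contribute
        # (0.5 * 10000 + 0.3 * 30000) / (10000 + 30000) = 0.35; the FUSED and
        # FUSED_UVM tables are ignored by the weighted average.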
        self.assertEqual(table_group.fused_params.get("cache_load_factor"), 0.35)
