Embedding grouper distinguishes between prefetched vs non-prefetched tables (pytorch#1859)

Summary:
Pull Request resolved: pytorch#1859

Fusing cached and non-cached tables together generally creates two problems for memory efficiency:
1. TBE cache will use

If `prefetch_pipeline` is set as a fused parameter, the training pipeline will try to call `prefetch()` in a separate stream one batch ahead of time. This process unfortunately consumes a lot of extra memory: at peak, roughly 8~9x the size of the input tensor.
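
A minimal sketch of that prefetch-ahead pattern, for illustration only (the `run_batch` helper and batch fields are assumptions; `prefetch(indices, offsets)` mirrors FBGEMM's TBE API, but this is not torchrec's actual pipeline code):

```python
import torch

# Side stream on which prefetch() for batch i+1 overlaps compute for batch i.
prefetch_stream = torch.cuda.Stream()

def run_batch(tbe, cur_batch, next_batch):
    # Warm the HBM cache for the *next* batch on the side stream. The
    # temporary buffers allocated here are what drive the ~8-9x peak memory,
    # which is why keeping prefetch-pipelined tables in smaller TBEs helps.
    with torch.cuda.stream(prefetch_stream):
        tbe.prefetch(next_batch.indices, next_batch.offsets)
    out = tbe(cur_batch.indices, cur_batch.offsets)  # lookup for current batch
    # Ensure the next batch's lookup sees the prefetched rows.
    torch.cuda.current_stream().wait_stream(prefetch_stream)
    return out
```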

Therefore, we want to minimize the size of the input tensor passed to `prefetch()` as much as possible. To achieve that, we avoid grouping tables that require prefetch together with tables that don't into the same TBE.

This diff does not change behavior for jobs without cached embedding offloading.

For embedding-offloaded jobs, this diff will slightly decrease TBE lookup performance, since it results in more TBEs (and subsequently more kernels in forward and backward), but it greatly increases memory efficiency. The grouping rules (illustrated in the sketch after this list):
1. If prefetch is off, all tables (regardless of cache status or dimension) are grouped together.
2. If prefetch is on:
    1. Cached vs non-cached tables are separated, even if they have the same dimension.
    2. Two cached tables with different dimensions are separated; otherwise they are grouped.
    3. Two non-cached tables are grouped regardless of dimension.
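
A minimal sketch of how these rules shape the grouping key, assuming simplified table objects and string kernel names (the real predicate and key construction appear in `embedding_sharding.py` in the diff below):

```python
def _prefetch_and_cached(table) -> bool:
    # True when the table uses a UVM-caching kernel and prefetch_pipeline
    # is enabled in its fused params.
    return (
        table.compute_kernel in ("FUSED_UVM_CACHING", "QUANT_UVM_CACHING")
        and bool((table.fused_params or {}).get("prefetch_pipeline"))
    )

def grouping_key(table, cached_bucketer, non_cached_bucketer):
    cached = _prefetch_and_cached(table)
    # Prefetched+cached tables use an ALL_BUCKETS bucketer (split by dim);
    # everything else uses a SINGLE_BUCKET bucketer (dim ignored).
    bucketer = cached_bucketer if cached else non_cached_bucketer
    return (
        table.data_type,
        table.pooling,
        bucketer.get_bucket(table.local_cols, table.data_type),
        cached,  # cached and non-cached tables never share a TBE
    )
```

Tables with equal keys are fused into one TBE; appending the `cached` bit to the key is what enforces rule 2.1 above.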

Reviewed By: henrylhtsang

Differential Revision: D55901328

fbshipit-source-id: 3ad3c26721a5a208177e060c2777087ca1273bb7
levythu authored and facebook-github-bot committed Apr 23, 2024
1 parent 3fbd547 commit 819ecf3
Showing 4 changed files with 59 additions and 55 deletions.
33 changes: 0 additions & 33 deletions torchrec/distributed/embedding_dim_bucketer.py
@@ -155,36 +155,3 @@ def bucket(self, dim: int, dtype: DataType) -> int:

def dim_in_bytes(self, dim: int, dtype: DataType) -> int:
return dim * DATA_TYPE_NUM_BITS[dtype] // 8


def should_do_dim_bucketing(
embedding_tables: List[ShardedEmbeddingTable],
) -> bool:
"""
When embedding memory offloading with caching is enabled, we prefer to
do dim bucketing for better utilization of cache space. Only applied to
"prefetch-sparse-dist" training pipeline.
Currently using the compute kernel to deduct caching is enabled.
"""
table_pipeline_count = 0
for table in embedding_tables:
if (
table.fused_params is not None
and "prefetch_pipeline" in table.fused_params
and table.fused_params["prefetch_pipeline"]
):
table_pipeline_count += 1

if table_pipeline_count > 0 and table_pipeline_count != len(embedding_tables):
AssertionError(
f"Only {table_pipeline_count} tables have prefetch-sparse-dist pipeline. It should be all {len(embedding_tables)} tables."
)

for table in embedding_tables:
if (
table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
and table_pipeline_count
):
return True
return False
49 changes: 42 additions & 7 deletions torchrec/distributed/embedding_sharding.py
@@ -12,6 +12,7 @@
import uuid
from collections import defaultdict
from dataclasses import dataclass
from itertools import filterfalse
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import torch
@@ -23,7 +24,6 @@
from torchrec.distributed.embedding_dim_bucketer import (
EmbDimBucketer,
EmbDimBucketerPolicy,
should_do_dim_bucketing,
)
from torchrec.distributed.embedding_types import (
BaseEmbeddingLookup,
@@ -242,6 +242,26 @@ def _get_compute_kernel_type(
return compute_kernel_type


def _prefetch_and_cached(
table: ShardedEmbeddingTable,
) -> bool:
"""
Return whether this embedding uses HBM as a cache. In this case we might want
to use the bucketizer to group tables by dimension for memory efficiency.
"""

return (
table.compute_kernel
in [
EmbeddingComputeKernel.FUSED_UVM_CACHING,
EmbeddingComputeKernel.QUANT_UVM_CACHING,
]
and table.fused_params is not None
and "prefetch_pipeline" in table.fused_params
and table.fused_params["prefetch_pipeline"]
)


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -262,12 +282,20 @@ def _group_tables_per_rank(
) -> List[GroupedEmbeddingConfig]:
grouped_embedding_configs: List[GroupedEmbeddingConfig] = []

emb_dim_bucketer_policy = (
EmbDimBucketerPolicy.ALL_BUCKETS
if should_do_dim_bucketing(embedding_tables)
else EmbDimBucketerPolicy.SINGLE_BUCKET
# We use different dim-bucketing policies for different cases:
# If prefetch is off, all tables (regardless of cache status or dimension) are grouped together (SINGLE_BUCKET).
# If prefetch is on:
#   Cached vs non-cached tables are separated, even if they have the same dimension.
#   Two cached tables with different dimensions are separated; otherwise they are grouped (ALL_BUCKETS).
#   Two non-cached tables are grouped regardless of dimension (SINGLE_BUCKET).
prefetch_cached_dim_bucketer = EmbDimBucketer(
list(filter(_prefetch_and_cached, embedding_tables)),
EmbDimBucketerPolicy.ALL_BUCKETS,
)
non_prefetch_cached_dim_bucketer = EmbDimBucketer(
list(filterfalse(_prefetch_and_cached, embedding_tables)),
EmbDimBucketerPolicy.SINGLE_BUCKET,
)
emb_dim_bucketer = EmbDimBucketer(embedding_tables, emb_dim_bucketer_policy)

# all embedding tables have the same weight status
is_weighted = (
@@ -279,13 +307,19 @@ def _group_tables_per_rank(
grouping_keys = []
for table in embedding_tables:
group_fused_params = _get_grouping_fused_params(table.fused_params) or {}
bucketer = (
prefetch_cached_dim_bucketer
if _prefetch_and_cached(table)
else non_prefetch_cached_dim_bucketer
)
grouping_key = (
table.data_type,
table.pooling,
table.has_feature_processor,
tuple(sorted(group_fused_params.items())),
_get_compute_kernel_type(table.compute_kernel),
emb_dim_bucketer.get_bucket(table.local_cols, table.data_type),
bucketer.get_bucket(table.local_cols, table.data_type),
_prefetch_and_cached(table),
)
# micromanage the order in which we traverse the groups to ensure backwards compatibility
if grouping_key not in groups:
@@ -300,6 +334,7 @@
fused_params_tuple,
compute_kernel_type,
_,
_,
) = grouping_key
grouped_tables = groups[grouping_key]
# remove non-native fused params
5 changes: 0 additions & 5 deletions torchrec/distributed/tests/test_emb_dim_bucketer.py
@@ -15,7 +15,6 @@
from torchrec.distributed.embedding_dim_bucketer import (
EmbDimBucketer,
EmbDimBucketerPolicy,
should_do_dim_bucketing,
)

from torchrec.distributed.embedding_types import (
@@ -101,7 +100,3 @@ def test_all_bucket_policy(self) -> None:

for i in range(emb_dim_bucketer.bucket_count()):
self.assertTrue(i in emb_dim_bucketer.emb_dim_buckets.values())

def test_should_do_dim_bucketing(self) -> None:
embedding_tables, _ = self.gen_tables()
self.assertFalse(should_do_dim_bucketing(embedding_tables))
27 changes: 17 additions & 10 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -23,6 +23,7 @@
_get_compute_kernel_type,
_get_grouping_fused_params,
_get_weighted_avg_cache_load_factor,
_prefetch_and_cached,
group_tables,
)

@@ -394,18 +395,24 @@ def test_should_not_group_together(
if distinct_key == "compute_kernel" and _get_compute_kernel_type(
compute_kernels[0]
) == _get_compute_kernel_type(compute_kernels[1]):
self.assertEqual(
_get_table_names_by_groups(tables),
[["table_0", "table_1"]],
)
# Typically, tables with the same kernel group (e.g. FUSED vs FUSED_UVM)
# would be grouped together. But if one of them is related to CACHE,
# we group them separately because we don't want to add the burden of
# prefetch().
if _prefetch_and_cached(tables[0]) != _prefetch_and_cached(tables[1]):
self.assertEqual(
sorted(_get_table_names_by_groups(tables)),
[["table_0"], ["table_1"]],
)
else:
self.assertEqual(
_get_table_names_by_groups(tables),
[["table_0", "table_1"]],
)
return

# emb dim bucketizer only in use when compute kernel is uvm caching
# and prefetch pipeline is True
if (
distinct_key == "local_dim"
and compute_kernels[0] != EmbeddingComputeKernel.FUSED_UVM_CACHING
):
# emb dim bucketizer only in use when compute kernel is caching
if distinct_key == "local_dim" and _prefetch_and_cached(tables[0]):
self.assertEqual(
_get_table_names_by_groups(tables),
[["table_0", "table_1"]],
