Enable Embedding Dim Bucketer (pytorch#1443)
Summary:
Pull Request resolved: pytorch#1443

This change enables the Embedding Dim Bucketer when UVM caching is used with the "prefetch-sparse-dist" pipeline.

The default policy in that case is CACHELINE_BUCKETS; otherwise, the SINGLE_BUCKET policy is used, which is equivalent to having no buckets.

Differential Revision: D50132647

fbshipit-source-id: 1ab90939480dd803843cdcaf4660363a03974559
ehsanardestani authored and facebook-github-bot committed Oct 27, 2023
1 parent e4153d6 commit f3ba8d9
Showing 3 changed files with 124 additions and 48 deletions.
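For orientation, a minimal sketch of how the new pieces fit together, mirroring the policy selection added to _group_tables_per_rank() below. The dims and row counts are illustrative, and any ShardedEmbeddingTable fields not listed are assumed to fall back to their defaults:

from torchrec.distributed.embedding_dim_bucketer import (
    EmbDimBucketer,
    EmbDimBucketerPolicy,
    should_do_dim_bucketing,
)
from torchrec.distributed.embedding_types import (
    EmbeddingComputeKernel,
    ShardedEmbeddingTable,
)
from torchrec.modules.embedding_configs import DataType

# UVM-cached tables that run the prefetch pipeline: the two properties
# that should_do_dim_bucketing() inspects.
tables = [
    ShardedEmbeddingTable(
        num_embeddings=1_000_000,
        embedding_dim=dim,
        data_type=DataType.FP16,
        compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
        fused_params={"prefetch_pipeline": True},
    )
    for dim in (64, 128, 256)
]

# Same policy selection as the code added to _group_tables_per_rank() below.
policy = (
    EmbDimBucketerPolicy.CACHELINE_BUCKETS
    if should_do_dim_bucketing(tables)
    else EmbDimBucketerPolicy.SINGLE_BUCKET
)
bucketer = EmbDimBucketer(tables, policy)
print(bucketer.bucket_count())
for table in tables:
    print(table.embedding_dim, bucketer.get_bucket(table.embedding_dim, table.data_type))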
38 changes: 37 additions & 1 deletion torchrec/distributed/embedding_dim_bucketer.py
@@ -9,7 +9,10 @@
from enum import Enum, unique
from typing import Dict, List

from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.distributed.embedding_types import (
EmbeddingComputeKernel,
ShardedEmbeddingTable,
)
from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType


@@ -153,3 +156,36 @@ def bucket(self, dim: int, dtype: DataType) -> int:

def dim_in_bytes(self, dim: int, dtype: DataType) -> int:
return dim * DATA_TYPE_NUM_BITS[dtype] // 8


def should_do_dim_bucketing(
embedding_tables: List[ShardedEmbeddingTable],
) -> bool:
"""
When embedding memory offloading with caching is enabled, we prefer to
do dim bucketing for better utilization of the cache space. This is only
applied to the "prefetch-sparse-dist" training pipeline.
Currently, the compute kernel is used to deduce whether caching is enabled.
"""
table_pipeline_count = 0
for table in embedding_tables:
if (
table.fused_params is not None
and "prefetch_pipeline" in table.fused_params
and table.fused_params["prefetch_pipeline"]
):
table_pipeline_count += 1

if table_pipeline_count > 0 and table_pipeline_count != len(embedding_tables):
raise AssertionError(
f"Only {table_pipeline_count} tables have the prefetch-sparse-dist pipeline. It should be all {len(embedding_tables)} tables."
)

for table in embedding_tables:
if (
table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
and table_pipeline_count
):
return True
return False
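For intuition on the byte math the bucketer uses (dim_in_bytes above): FP16 is 16 bits per element, so a row with embedding_dim=128 occupies 128 * 16 // 8 = 256 bytes. The negative case of should_do_dim_bucketing, which test_should_do_dim_bucketing below also exercises, can be sketched with the same imports as the example above:

# The FUSED_UVM_CACHING kernel alone is not enough; the "prefetch_pipeline"
# fused param must also be set on every table for bucketing to kick in.
cache_only_tables = [
    ShardedEmbeddingTable(
        num_embeddings=500_000,
        embedding_dim=128,
        data_type=DataType.FP16,
        compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
    )
]
assert not should_do_dim_bucketing(cache_only_tables)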
123 changes: 77 additions & 46 deletions torchrec/distributed/embedding_sharding.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import abc
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

@@ -15,6 +16,11 @@
KJTAllToAllTensorsAwaitable,
SplitsAllToAllAwaitable,
)
from torchrec.distributed.embedding_dim_bucketer import (
EmbDimBucketer,
EmbDimBucketerPolicy,
should_do_dim_bucketing,
)
from torchrec.distributed.embedding_types import (
BaseEmbeddingLookup,
BaseGroupedFeatureProcessor,
@@ -151,59 +157,84 @@ def _group_tables_per_rank(
EmbeddingComputeKernel.QUANT,
]
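# Choose the dim-bucketing policy: CACHELINE_BUCKETS when all tables are
# UVM-cached and run the prefetch pipeline, otherwise SINGLE_BUCKET
# (equivalent to having no buckets).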

emb_dim_bucketer_cfg = (
EmbDimBucketerPolicy.CACHELINE_BUCKETS
if should_do_dim_bucketing(embedding_tables)
else EmbDimBucketerPolicy.SINGLE_BUCKET
)
emb_dim_bucketer = EmbDimBucketer(embedding_tables, emb_dim_bucketer_cfg)
logging.info(f"bucket count {emb_dim_bucketer.bucket_count()}")

for data_type in DataType:
for pooling in PoolingType:
# remove this when finishing migration
for has_feature_processor in [False, True]:
for fused_params_group in fused_params_groups:
for compute_kernel in compute_kernels:
grouped_tables: List[ShardedEmbeddingTable] = []
is_weighted = False
for table in embedding_tables:
compute_kernel_type = table.compute_kernel
is_weighted = table.is_weighted
if table.compute_kernel in [
EmbeddingComputeKernel.FUSED_UVM,
EmbeddingComputeKernel.FUSED_UVM_CACHING,
]:
compute_kernel_type = EmbeddingComputeKernel.FUSED
elif table.compute_kernel in [
EmbeddingComputeKernel.QUANT_UVM,
EmbeddingComputeKernel.QUANT_UVM_CACHING,
]:
compute_kernel_type = EmbeddingComputeKernel.QUANT
if (
table.data_type == data_type
and table.pooling.value == pooling.value
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
and table.fused_params == fused_params_group
):
grouped_tables.append(table)

if fused_params_group is None:
fused_params_group = {}

if grouped_tables:
grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
pooling=pooling,
is_weighted=is_weighted,
has_feature_processor=has_feature_processor,
compute_kernel=compute_kernel,
embedding_tables=grouped_tables,
fused_params={
k: v
for k, v in fused_params_group.items()
if k
not in [
"_batch_key"
] # drop '_batch_key' not a native fused param
},
for dim_bucket in range(emb_dim_bucketer.bucket_count()):
grouped_tables: List[ShardedEmbeddingTable] = []
is_weighted = False
for table in embedding_tables:
compute_kernel_type = table.compute_kernel
is_weighted = table.is_weighted
if table.compute_kernel in [
EmbeddingComputeKernel.FUSED_UVM,
EmbeddingComputeKernel.FUSED_UVM_CACHING,
]:
compute_kernel_type = (
EmbeddingComputeKernel.FUSED
)
elif table.compute_kernel in [
EmbeddingComputeKernel.QUANT_UVM,
EmbeddingComputeKernel.QUANT_UVM_CACHING,
]:
compute_kernel_type = (
EmbeddingComputeKernel.QUANT
)

if (
table.data_type == data_type
and table.pooling.value == pooling.value
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
and table.fused_params == fused_params_group
and (
emb_dim_bucketer.get_bucket(
table.embedding_dim, table.data_type
)
== dim_bucket
)
):
grouped_tables.append(table)

if len(grouped_tables) > 0:
logging.info(
f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
)

if fused_params_group is None:
fused_params_group = {}

if grouped_tables:
grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
pooling=pooling,
is_weighted=is_weighted,
has_feature_processor=has_feature_processor,
compute_kernel=compute_kernel,
embedding_tables=grouped_tables,
fused_params={
k: v
for k, v in fused_params_group.items()
if k
not in [
"_batch_key"
] # drop '_batch_key' not a native fused param
},
)
)
)
return grouped_embedding_configs

table_weightedness = [
11 changes: 10 additions & 1 deletion torchrec/distributed/tests/test_emb_dim_bucketer.py
@@ -13,9 +13,13 @@
from torchrec.distributed.embedding_dim_bucketer import (
EmbDimBucketer,
EmbDimBucketerPolicy,
should_do_dim_bucketing,
)

from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.distributed.embedding_types import (
EmbeddingComputeKernel,
ShardedEmbeddingTable,
)
from torchrec.modules.embedding_configs import DataType


@@ -36,6 +40,7 @@ def gen_tables(self) -> Tuple[List[ShardedEmbeddingTable], int]:
embedding_dim=buckets[i % num_buckets],
num_embeddings=random.randint(100, 500000),
data_type=DataType.FP16,
compute_kernel=EmbeddingComputeKernel.FUSED_UVM_CACHING,
)
)
return embeddings, len(buckets)
@@ -86,3 +91,7 @@ def test_all_bucket_policy(self) -> None:

for i in range(emb_dim_bucketer.bucket_count()):
self.assertTrue(i in emb_dim_bucketer.emb_dim_buckets.values())

def test_should_do_dim_bucketing(self) -> None:
embedding_tables, _ = self.gen_tables()
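# gen_tables() uses the FUSED_UVM_CACHING kernel but does not set the
# "prefetch_pipeline" fused param (see the diff above), so dim bucketing
# is expected to stay disabled.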
self.assertFalse(should_do_dim_bucketing(embedding_tables))
