Refactor group_tables (pytorch#1516)
Summary:

Refactor group_tables.

The unit tests were created using code from before this change; see D50483066 and D50482054. I had to make some changes to the prefetch test, since at the end we switch from the cacheline policy to the all_buckets policy.

Differential Revision: D51220870
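
To make the shape of the refactor concrete: the old code iterated over every combination of DataType, PoolingType, feature-processor flag, fused-params group, compute kernel, and dim bucket, whereas the refactored code derives one grouping key per table and groups tables that share a key. Below is a minimal sketch of that idea, using simplified stand-in types rather than torchrec's ShardedEmbeddingTable and keying on only three of the dimensions (the real implementation in the diff collects the distinct keys first and then re-scans the tables for each key):

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Table:  # simplified stand-in for ShardedEmbeddingTable
    name: str
    data_type: str
    pooling: str
    compute_kernel: str


def group_by_key(tables: List[Table]) -> Dict[Tuple[str, str, str], List[Table]]:
    # Build one key per table; tables sharing a key end up in the same group.
    groups: Dict[Tuple[str, str, str], List[Table]] = defaultdict(list)
    for table in tables:
        key = (table.data_type, table.pooling, table.compute_kernel)
        groups[key].append(table)
    return groups


tables = [
    Table("t0", "FP16", "SUM", "FUSED"),
    Table("t1", "FP16", "SUM", "FUSED"),
    Table("t2", "FP32", "MEAN", "QUANT"),
]
for key, group in group_by_key(tables).items():
    print(key, [t.name for t in group])
# ('FP16', 'SUM', 'FUSED') ['t0', 't1']
# ('FP32', 'MEAN', 'QUANT') ['t2']
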
hlhtsang authored and facebook-github-bot committed Nov 16, 2023
1 parent 157aa15 commit 9610809
Showing 3 changed files with 326 additions and 104 deletions.
201 changes: 98 additions & 103 deletions torchrec/distributed/embedding_sharding.py
@@ -40,11 +40,7 @@
ShardMetadata,
)
from torchrec.fx.utils import assert_fx_safe
from torchrec.modules.embedding_configs import (
DataType,
EmbeddingTableConfig,
PoolingType,
)
from torchrec.modules.embedding_configs import EmbeddingTableConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable

@@ -172,6 +168,26 @@ def _get_grouping_fused_params(
return grouping_fused_params


def _get_compute_kernel_type(
compute_kernel: EmbeddingComputeKernel,
) -> EmbeddingComputeKernel:
"""
Return the compute kernel type for the given compute kernel.
"""
compute_kernel_type = compute_kernel
if compute_kernel_type in [
EmbeddingComputeKernel.FUSED_UVM,
EmbeddingComputeKernel.FUSED_UVM_CACHING,
]:
compute_kernel_type = EmbeddingComputeKernel.FUSED
elif compute_kernel_type in [
EmbeddingComputeKernel.QUANT_UVM,
EmbeddingComputeKernel.QUANT_UVM_CACHING,
]:
compute_kernel_type = EmbeddingComputeKernel.QUANT
return compute_kernel_type


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -192,22 +208,6 @@ def _group_tables_per_rank(
) -> List[GroupedEmbeddingConfig]:
grouped_embedding_configs: List[GroupedEmbeddingConfig] = []

# populate grouping fused params
fused_params_groups = []
for table in embedding_tables:
if table.fused_params is None:
table.fused_params = {}

grouping_fused_params = _get_grouping_fused_params(table.fused_params)
if grouping_fused_params not in fused_params_groups:
fused_params_groups.append(grouping_fused_params)

compute_kernels = [
EmbeddingComputeKernel.DENSE,
EmbeddingComputeKernel.FUSED,
EmbeddingComputeKernel.QUANT,
]

emb_dim_bucketer_policy = (
EmbDimBucketerPolicy.ALL_BUCKETS
if should_do_dim_bucketing(embedding_tables)
@@ -216,88 +216,83 @@ def _group_tables_per_rank(
emb_dim_bucketer = EmbDimBucketer(embedding_tables, emb_dim_bucketer_policy)
logging.info(f"bucket count {emb_dim_bucketer.bucket_count()}")

for data_type in DataType:
for pooling in PoolingType:
# remove this when finishing migration
for has_feature_processor in [False, True]:
for fused_params_group in fused_params_groups:
for compute_kernel in compute_kernels:
for dim_bucket in range(emb_dim_bucketer.bucket_count()):
grouped_tables: List[ShardedEmbeddingTable] = []
is_weighted = False
for table in embedding_tables:
compute_kernel_type = table.compute_kernel
is_weighted = table.is_weighted
if table.compute_kernel in [
EmbeddingComputeKernel.FUSED_UVM,
EmbeddingComputeKernel.FUSED_UVM_CACHING,
]:
compute_kernel_type = (
EmbeddingComputeKernel.FUSED
)
elif table.compute_kernel in [
EmbeddingComputeKernel.QUANT_UVM,
EmbeddingComputeKernel.QUANT_UVM_CACHING,
]:
compute_kernel_type = (
EmbeddingComputeKernel.QUANT
)

if (
table.data_type == data_type
and table.pooling.value == pooling.value
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
and fused_params_group
== _get_grouping_fused_params(
table.fused_params
)
and (
emb_dim_bucketer.get_bucket(
table.embedding_dim, table.data_type
)
== dim_bucket
)
):
grouped_tables.append(table)

if fused_params_group is None:
fused_params_group = {}

if grouped_tables:
logging.info(
f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
)
cache_load_factor = (
_get_weighted_avg_cache_load_factor(
grouped_tables
)
)
per_tbe_fused_params = copy.copy(fused_params_group)
if cache_load_factor is not None:
per_tbe_fused_params[
CACHE_LOAD_FACTOR_STR
] = cache_load_factor

grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
pooling=pooling,
is_weighted=is_weighted,
has_feature_processor=has_feature_processor,
compute_kernel=compute_kernel,
embedding_tables=grouped_tables,
fused_params={
k: v
for k, v in per_tbe_fused_params.items()
if k
not in [
"_batch_key"
] # drop '_batch_key' not a native fused param
},
)
)
# populate grouping keys
grouping_keys = []
for table in embedding_tables:
if table.fused_params is None:
table.fused_params = {}

grouping_key = (
table.data_type,
table.pooling,
table.has_feature_processor,
_get_grouping_fused_params(table.fused_params),
_get_compute_kernel_type(table.compute_kernel),
emb_dim_bucketer.get_bucket(table.embedding_dim, table.data_type),
)
if grouping_key not in grouping_keys:
grouping_keys.append(grouping_key)

for grouping_key in grouping_keys:
(
data_type,
pooling,
has_feature_processor,
fused_params_group,
compute_kernel_type,
dim_bucket,
) = grouping_key
grouped_tables: List[ShardedEmbeddingTable] = []
is_weighted = False
for table in embedding_tables:
is_weighted = table.is_weighted
if (
table.data_type == data_type
and table.pooling.value == pooling.value
and table.has_feature_processor == has_feature_processor
and fused_params_group
== _get_grouping_fused_params(table.fused_params)
and compute_kernel_type
== _get_compute_kernel_type(table.compute_kernel)
and (
emb_dim_bucketer.get_bucket(
table.embedding_dim, table.data_type
)
== dim_bucket
)
):
grouped_tables.append(table)

if fused_params_group is None:
fused_params_group = {}

if grouped_tables:
logging.info(
f"{len(grouped_tables)} tables are grouped for bucket: {dim_bucket}."
)
cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables)
per_tbe_fused_params = copy.copy(fused_params_group)
if cache_load_factor is not None:
per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor

grouped_embedding_configs.append(
GroupedEmbeddingConfig(
data_type=data_type,
pooling=pooling,
is_weighted=is_weighted,
has_feature_processor=has_feature_processor,
compute_kernel=compute_kernel_type,
embedding_tables=grouped_tables,
fused_params={
k: v
for k, v in per_tbe_fused_params.items()
if k
not in [
"_batch_key"
] # drop '_batch_key' not a native fused param
},
)
)
return grouped_embedding_configs

table_weightedness = [
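
As an aside on the new _get_compute_kernel_type helper in the diff above: UVM and UVM-caching kernel variants collapse to their base FUSED or QUANT kernel so that they can land in the same table group. A self-contained sketch of that mapping, using a stand-in enum instead of torchrec's EmbeddingComputeKernel:

from enum import Enum


class Kernel(Enum):  # stand-in for torchrec's EmbeddingComputeKernel
    DENSE = "dense"
    FUSED = "fused"
    FUSED_UVM = "fused_uvm"
    FUSED_UVM_CACHING = "fused_uvm_caching"
    QUANT = "quant"
    QUANT_UVM = "quant_uvm"
    QUANT_UVM_CACHING = "quant_uvm_caching"


def normalize_kernel(kernel: Kernel) -> Kernel:
    # UVM and UVM-caching variants share a group with their base kernel.
    if kernel in (Kernel.FUSED_UVM, Kernel.FUSED_UVM_CACHING):
        return Kernel.FUSED
    if kernel in (Kernel.QUANT_UVM, Kernel.QUANT_UVM_CACHING):
        return Kernel.QUANT
    return kernel


assert normalize_kernel(Kernel.FUSED_UVM_CACHING) is Kernel.FUSED
assert normalize_kernel(Kernel.QUANT_UVM) is Kernel.QUANT
assert normalize_kernel(Kernel.DENSE) is Kernel.DENSE
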
6 changes: 6 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -199,6 +199,12 @@ def dim_sum(self) -> int:
dim_sum += table.num_features() * table.local_cols
return dim_sum

def table_names(self) -> List[str]:
table_names = []
for table in self.embedding_tables:
table_names.append(table.name)
return table_names

def feature_names(self) -> List[str]:
feature_names = []
for table in self.embedding_tables:
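
The new table_names() accessor in embedding_types.py mirrors the existing feature_names() pattern: collect one attribute per table in the grouped config. A small usage sketch with stand-in classes (the class and table names here are illustrative, not taken from the diff):

from dataclasses import dataclass, field
from typing import List


@dataclass
class _Table:  # stand-in for ShardedEmbeddingTable
    name: str


@dataclass
class _GroupedConfig:  # stand-in for GroupedEmbeddingConfig
    embedding_tables: List[_Table] = field(default_factory=list)

    def table_names(self) -> List[str]:
        # Same shape as the accessor added in the diff: one name per table.
        return [table.name for table in self.embedding_tables]


config = _GroupedConfig(embedding_tables=[_Table("table_a"), _Table("table_b")])
assert config.table_names() == ["table_a", "table_b"]
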
