Add KeyValueParams
Summary:
Add a KeyValueParams class that holds the params to be passed to the SSD TBE.

Expectations:
* Pass to the SSD TBE only when using EmbeddingComputeKernel.KEY_VALUE. This is important so that a mix of FUSED and KEY_VALUE tables can be used together.
* Must be hashable (see the usage sketch below).

Differential Revision: D58892592
henrylhtsang authored and facebook-github-bot committed Jun 25, 2024
1 parent 704afbe commit b3a4a68
Showing 6 changed files with 58 additions and 0 deletions.
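
As an editorial aside, here is a minimal usage sketch of the new field: one table constrained to the FUSED kernel and one to the SSD-backed KEY_VALUE kernel. The compute_kernels field and the table names are assumptions drawn from the existing planner API, not part of this commit.

    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import KeyValueParams

    constraints = {
        # HBM-resident table: keep the fused kernel, no key/value params needed.
        "table_fused": ParameterConstraints(
            compute_kernels=[EmbeddingComputeKernel.FUSED.value],
        ),
        # SSD-backed table: KEY_VALUE kernel plus params forwarded to the SSD TBE.
        "table_ssd": ParameterConstraints(
            compute_kernels=[EmbeddingComputeKernel.KEY_VALUE.value],
            key_value_params=KeyValueParams(
                # "@local_rank" is replaced with the process's local rank.
                ssd_storage_directory="/data00_nvidia@local_rank",
            ),
        ),
    }
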
6 changes: 6 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -43,6 +43,7 @@
)
from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
from torch import nn
from torchrec.distributed.comm import get_local_rank
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
@@ -133,6 +134,11 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:

if "ssd_storage_directory" not in ssd_tbe_params:
ssd_tbe_params["ssd_storage_directory"] = tempfile.mkdtemp()
else:
directory = ssd_tbe_params["ssd_storage_directory"]
if "@local_rank" in directory:
# assume we have initialized a process group already
ssd_tbe_params["ssd_storage_directory"] = directory.replace("@local_rank", str(get_local_rank()))

if "weights_precision" not in ssd_tbe_params:
weights_precision = data_type_to_sparse_type(config.data_type)
7 changes: 7 additions & 0 deletions torchrec/distributed/planner/enumerators.py
@@ -31,6 +31,7 @@
from torchrec.distributed.types import (
BoundsCheckMode,
CacheParams,
KeyValueParams,
ModuleSharder,
ShardingType,
)
@@ -154,6 +155,7 @@ def enumerate(
feature_names,
output_dtype,
device_group,
key_value_params,
) = _extract_constraints_for_param(self._constraints, name)

# skip for other device groups
@@ -209,6 +211,7 @@ def enumerate(
is_pooled=is_pooled,
feature_names=feature_names,
output_dtype=output_dtype,
key_value_params=key_value_params,
)
)
if not sharding_options_per_table:
@@ -315,6 +318,7 @@ def _extract_constraints_for_param(
Optional[List[str]],
Optional[DataType],
Optional[str],
Optional[KeyValueParams],
]:
input_lengths = [POOLING_FACTOR]
col_wise_shard_dim = None
@@ -325,6 +329,7 @@
feature_names = None
output_dtype = None
device_group = None
key_value_params = None

if constraints and constraints.get(name):
input_lengths = constraints[name].pooling_factors
@@ -336,6 +341,7 @@
feature_names = constraints[name].feature_names
output_dtype = constraints[name].output_dtype
device_group = constraints[name].device_group
key_value_params = constraints[name].key_value_params

return (
input_lengths,
@@ -347,6 +353,7 @@
feature_names,
output_dtype,
device_group,
key_value_params,
)


1 change: 1 addition & 0 deletions torchrec/distributed/planner/planners.py
@@ -106,6 +106,7 @@ def _to_sharding_plan(
stochastic_rounding=sharding_option.stochastic_rounding,
bounds_check_mode=sharding_option.bounds_check_mode,
output_dtype=sharding_option.output_dtype,
key_value_params=sharding_option.key_value_params,
)
plan[sharding_option.path] = module_plan
return ShardingPlan(plan)
8 changes: 8 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -30,6 +30,7 @@
from torchrec.distributed.types import (
BoundsCheckMode,
CacheParams,
KeyValueParams,
ModuleSharder,
ShardingPlan,
)
@@ -368,6 +369,8 @@ class ShardingOption:
output_dtype (Optional[DataType]): output dtype to be used by this table.
The default is FP32. If not None, the output dtype will also be used
by the planner to produce a more balanced plan.
key_value_params (Optional[KeyValueParams]): Params for SSD TBE, either
for SSD or PS.
"""

def __init__(
@@ -389,6 +392,7 @@ def __init__(
is_pooled: Optional[bool] = None,
feature_names: Optional[List[str]] = None,
output_dtype: Optional[DataType] = None,
key_value_params: Optional[KeyValueParams] = None,
) -> None:
self.name = name
self._tensor = tensor
@@ -410,6 +414,7 @@ def __init__(
self.is_weighted: Optional[bool] = None
self.feature_names: Optional[List[str]] = feature_names
self.output_dtype: Optional[DataType] = output_dtype
self.key_value_params: Optional[KeyValueParams] = key_value_params

@property
def tensor(self) -> torch.Tensor:
@@ -574,6 +579,8 @@ class ParameterConstraints:
device_group (Optional[str]): device group to be used by this table. It can be cpu
or cuda. This specifies if the table should be placed on a cpu device
or a gpu device.
key_value_params (Optional[KeyValueParams]): key value params for SSD TBE, either for
SSD or PS.
"""

sharding_types: Optional[List[str]] = None
@@ -592,6 +599,7 @@
feature_names: Optional[List[str]] = None
output_dtype: Optional[DataType] = None
device_group: Optional[str] = None
key_value_params: Optional[KeyValueParams] = None


class PlannerErrorType(Enum):
25 changes: 25 additions & 0 deletions torchrec/distributed/types.py
@@ -20,6 +20,7 @@
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
@@ -576,6 +577,28 @@ def __hash__(self) -> int:
)


@dataclass
class KeyValueParams:
"""
Params for SSD TBE, aka SSDTableBatchedEmbeddingBags.
Attributes:
ssd_storage_directory (Optional[str]): Directory for SSD. If we want the directory
to be f"data00_nvidia{local_rank}", pass in "data00_nvidia@local_rank".
ps_hosts (Optional[Tuple[Tuple[str, int], ...]]): Parameter server hosts as
(host, port) pairs; used when the key/value backend is a parameter server (PS)
rather than local SSD.
"""

ssd_storage_directory: Optional[str] = None
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None

def __hash__(self) -> int:
return hash(
(
self.ssd_storage_directory,
self.ps_hosts,
)
)


@dataclass
class ParameterSharding:
"""
@@ -591,6 +614,7 @@ class ParameterSharding:
stochastic_rounding (Optional[bool]): whether to use stochastic rounding.
bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode.
output_dtype (Optional[DataType]): output dtype.
key_value_params (Optional[KeyValueParams]): key value params for SSD TBE or PS.
NOTE:
ShardingType.TABLE_WISE - rank where this embedding is placed
@@ -610,6 +634,7 @@
stochastic_rounding: Optional[bool] = None
bounds_check_mode: Optional[BoundsCheckMode] = None
output_dtype: Optional[DataType] = None
key_value_params: Optional[KeyValueParams] = None


class EmbeddingModuleShardingPlan(ModuleShardingPlan, Dict[str, ParameterSharding]):
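
As an aside (not from the diff): a small sketch of the behavior the new dataclass is expected to have, namely the "@local_rank" placeholder described in its docstring and the hashability called out in the summary. The values are illustrative.

    from torchrec.distributed.types import KeyValueParams

    # Equal params compare equal and hash identically, so KeyValueParams can be
    # deduplicated or used as part of hashable fused-param groupings.
    a = KeyValueParams(ssd_storage_directory="/data00_nvidia@local_rank")
    b = KeyValueParams(ssd_storage_directory="/data00_nvidia@local_rank")
    assert a == b and hash(a) == hash(b)

    # PS variant: hosts are stored as a tuple of (host, port) pairs so the
    # dataclass stays hashable (a list would break __hash__).
    c = KeyValueParams(ps_hosts=(("host0", 29500), ("host1", 29500)))
    assert len({a, c}) == 2

    # The "@local_rank" placeholder is resolved when the SSD TBE is created,
    # e.g. "/data00_nvidia0" on local rank 0 and "/data00_nvidia1" on rank 1.
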
11 changes: 11 additions & 0 deletions torchrec/distributed/utils.py
@@ -13,6 +13,7 @@
import sys

from collections import OrderedDict
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union

import torch
@@ -405,6 +406,16 @@ def add_params_from_parameter_sharding(
if parameter_sharding.output_dtype is not None:
fused_params["output_dtype"] = parameter_sharding.output_dtype

if (
parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}
and parameter_sharding.key_value_params is not None
):
key_value_params_dict = asdict(parameter_sharding.key_value_params)
key_value_params_dict = {
k: v for k, v in key_value_params_dict.items() if v is not None
}
fused_params.update(key_value_params_dict)

# print warning if sharding_type is data_parallel or kernel is dense
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
logger.warning(
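
As an aside (not from the diff): a sketch of what add_params_from_parameter_sharding is expected to contribute to fused_params when the KEY_VALUE kernel is selected; the dict values are illustrative.

    from dataclasses import asdict

    from torchrec.distributed.types import KeyValueParams

    kv = KeyValueParams(ssd_storage_directory="/data00_nvidia@local_rank")

    # Only non-None fields are merged, and only for KEY_VALUE kernels, so
    # FUSED tables are unaffected by the new params.
    fused_params = {}
    fused_params.update({k: v for k, v in asdict(kv).items() if v is not None})
    # fused_params == {"ssd_storage_directory": "/data00_nvidia@local_rank"};
    # ps_hosts is dropped because it is None.
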
