From b3a4a68f179334eb41dacbbad93d88ecfee20ee9 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 24 Jun 2024 18:40:04 -0700 Subject: [PATCH] Add KeyValueParams Summary: Add KeyValueParams class, that are for params to go to SSD TBE. Expectation: * pass to SSD TBE only when using EmbeddingComputeKernel.KEY_VALUE. This is important to make sure we can use a mixed of FUSED and KEY_VALUE tables. * need to be hashable Differential Revision: D58892592 --- .../distributed/batched_embedding_kernel.py | 6 +++++ torchrec/distributed/planner/enumerators.py | 7 ++++++ torchrec/distributed/planner/planners.py | 1 + torchrec/distributed/planner/types.py | 8 ++++++ torchrec/distributed/types.py | 25 +++++++++++++++++++ torchrec/distributed/utils.py | 11 ++++++++ 6 files changed, 58 insertions(+) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 02a9b98a2..5e4836a9f 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -43,6 +43,7 @@ ) from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags from torch import nn +from torchrec.distributed.comm import get_local_rank from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, ) @@ -133,6 +134,11 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: if "ssd_storage_directory" not in ssd_tbe_params: ssd_tbe_params["ssd_storage_directory"] = tempfile.mkdtemp() + else: + directory = ssd_tbe_params["ssd_storage_directory"] + if "@local_rank" in directory: + # assume we have initialized a process group already + directory = directory.replace("@local_rank", str(get_local_rank())) if "weights_precision" not in ssd_tbe_params: weights_precision = data_type_to_sparse_type(config.data_type) diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 2d3e66eef..7554b74be 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -31,6 +31,7 @@ from torchrec.distributed.types import ( BoundsCheckMode, CacheParams, + KeyValueParams, ModuleSharder, ShardingType, ) @@ -154,6 +155,7 @@ def enumerate( feature_names, output_dtype, device_group, + key_value_params, ) = _extract_constraints_for_param(self._constraints, name) # skip for other device groups @@ -209,6 +211,7 @@ def enumerate( is_pooled=is_pooled, feature_names=feature_names, output_dtype=output_dtype, + key_value_params=key_value_params, ) ) if not sharding_options_per_table: @@ -315,6 +318,7 @@ def _extract_constraints_for_param( Optional[List[str]], Optional[DataType], Optional[str], + Optional[KeyValueParams], ]: input_lengths = [POOLING_FACTOR] col_wise_shard_dim = None @@ -325,6 +329,7 @@ def _extract_constraints_for_param( feature_names = None output_dtype = None device_group = None + key_value_params = None if constraints and constraints.get(name): input_lengths = constraints[name].pooling_factors @@ -336,6 +341,7 @@ def _extract_constraints_for_param( feature_names = constraints[name].feature_names output_dtype = constraints[name].output_dtype device_group = constraints[name].device_group + key_value_params = constraints[name].key_value_params return ( input_lengths, @@ -347,6 +353,7 @@ def _extract_constraints_for_param( feature_names, output_dtype, device_group, + key_value_params, ) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index 7e3630a71..8ab0ea7b1 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -106,6 +106,7 @@ def _to_sharding_plan( stochastic_rounding=sharding_option.stochastic_rounding, bounds_check_mode=sharding_option.bounds_check_mode, output_dtype=sharding_option.output_dtype, + key_value_params=sharding_option.key_value_params, ) plan[sharding_option.path] = module_plan return ShardingPlan(plan) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index b20d96673..21e205d15 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -30,6 +30,7 @@ from torchrec.distributed.types import ( BoundsCheckMode, CacheParams, + KeyValueParams, ModuleSharder, ShardingPlan, ) @@ -368,6 +369,8 @@ class ShardingOption: output_dtype (Optional[DataType]): output dtype to be used by this table. The default is FP32. If not None, the output dtype will also be used by the planner to produce a more balanced plan. + key_value_params (Optional[KeyValueParams]): Params for SSD TBE, either + for SSD or PS. """ def __init__( @@ -389,6 +392,7 @@ def __init__( is_pooled: Optional[bool] = None, feature_names: Optional[List[str]] = None, output_dtype: Optional[DataType] = None, + key_value_params: Optional[KeyValueParams] = None, ) -> None: self.name = name self._tensor = tensor @@ -410,6 +414,7 @@ def __init__( self.is_weighted: Optional[bool] = None self.feature_names: Optional[List[str]] = feature_names self.output_dtype: Optional[DataType] = output_dtype + self.key_value_params: Optional[KeyValueParams] = key_value_params @property def tensor(self) -> torch.Tensor: @@ -574,6 +579,8 @@ class ParameterConstraints: device_group (Optional[str]): device group to be used by this table. It can be cpu or cuda. This specifies if the table should be placed on a cpu device or a gpu device. + key_value_params (Optional[KeyValueParams]): key value params for SSD TBE, either for + SSD or PS. """ sharding_types: Optional[List[str]] = None @@ -592,6 +599,7 @@ class ParameterConstraints: feature_names: Optional[List[str]] = None output_dtype: Optional[DataType] = None device_group: Optional[str] = None + key_value_params: Optional[KeyValueParams] = None class PlannerErrorType(Enum): diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 290bdffbf..8505d936b 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -20,6 +20,7 @@ Iterator, List, Optional, + Tuple, Type, TypeVar, Union, @@ -576,6 +577,28 @@ def __hash__(self) -> int: ) +@dataclass +class KeyValueParams: + """ + Params for SSD TBE aka SSDTableBatchedEmbeddingBags. + + Attributes: + ssd_storage_directory (Optional[str]): Directory for SSD. If we want directory + to be f"data00_nvidia{local_rank}", pass in "data00_nvidia@local_rank". + """ + + ssd_storage_directory: Optional[str] = None + ps_hosts: Optional[Tuple[Tuple[str, int]]] = None + + def __hash__(self) -> int: + return hash( + ( + self.ssd_storage_directory, + self.ps_hosts, + ) + ) + + @dataclass class ParameterSharding: """ @@ -591,6 +614,7 @@ class ParameterSharding: stochastic_rounding (Optional[bool]): whether to use stochastic rounding. bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode. output_dtype (Optional[DataType]): output dtype. + key_value_params (Optional[KeyValueParams]): key value params for SSD TBE or PS. NOTE: ShardingType.TABLE_WISE - rank where this embedding is placed @@ -610,6 +634,7 @@ class ParameterSharding: stochastic_rounding: Optional[bool] = None bounds_check_mode: Optional[BoundsCheckMode] = None output_dtype: Optional[DataType] = None + key_value_params: Optional[KeyValueParams] = None class EmbeddingModuleShardingPlan(ModuleShardingPlan, Dict[str, ParameterSharding]): diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 26efb1263..541dc2ff5 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -13,6 +13,7 @@ import sys from collections import OrderedDict +from dataclasses import asdict from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union import torch @@ -405,6 +406,16 @@ def add_params_from_parameter_sharding( if parameter_sharding.output_dtype is not None: fused_params["output_dtype"] = parameter_sharding.output_dtype + if ( + parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value} + and parameter_sharding.key_value_params is not None + ): + key_value_params_dict = asdict(parameter_sharding.key_value_params) + key_value_params_dict = { + k: v for k, v in key_value_params_dict.items() if v is not None + } + fused_params.update(key_value_params_dict) + # print warning if sharding_type is data_parallel or kernel is dense if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: logger.warning(