diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 04afb8fd9..dc05d6027 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -47,7 +47,6 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable - torch.fx.wrap("len") CACHE_LOAD_FACTOR_STR: str = "cache_load_factor" @@ -62,15 +61,6 @@ def _fx_wrap_tensor_to_device_dtype( return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) -@torch.fx.wrap -def _fx_wrap_optional_tensor_to_device_dtype( - t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor -) -> Optional[torch.Tensor]: - if t is None: - return None - return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) - - @torch.fx.wrap def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]: return ( @@ -131,7 +121,6 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( block_sizes: torch.Tensor, bucketize_pos: bool = False, block_bucketize_pos: Optional[List[torch.Tensor]] = None, - total_num_blocks: Optional[torch.Tensor] = None, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -153,7 +142,6 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( bucketize_pos=bucketize_pos, sequence=True, block_sizes=block_sizes, - total_num_blocks=total_num_blocks, my_size=num_buckets, weights=kjt.weights_or_none(), max_B=_fx_wrap_max_B(kjt), @@ -301,7 +289,6 @@ def bucketize_kjt_inference( kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: torch.Tensor, - total_num_buckets: Optional[torch.Tensor] = None, bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, is_sequence: bool = False, @@ -316,7 +303,6 @@ def bucketize_kjt_inference( Args: num_buckets (int): number of buckets to bucketize the values into. block_sizes: (torch.Tensor): bucket sizes for the keyed dimension. - total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization bucketize_pos (bool): output the changed position of the bucketized values or not. block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature. @@ -332,9 +318,6 @@ def bucketize_kjt_inference( f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.", ) block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()) - total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype( - total_num_buckets, kjt.values() - ) unbucketize_permute = None bucket_mapping = None if is_sequence: @@ -349,7 +332,6 @@ def bucketize_kjt_inference( kjt, num_buckets=num_buckets, block_sizes=block_sizes_new_type, - total_num_blocks=total_num_buckets_new_type, bucketize_pos=bucketize_pos, block_bucketize_pos=block_bucketize_row_pos, ) diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 96bd40152..a59d7bde2 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -12,14 +12,14 @@ import logging import math from collections import defaultdict, OrderedDict -from dataclasses import dataclass -from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union +from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type import torch import torch.distributed as dist + from torch import nn -from torch.distributed._shard.sharded_tensor import Shard, ShardMetadata -from torch.fx._symbolic_trace import is_fx_tracing +from torch.distributed._shard.sharded_tensor import Shard +from torchrec.distributed.embedding import EmbeddingCollectionContext from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -30,22 +30,16 @@ BaseEmbeddingSharder, GroupedEmbeddingConfig, KJTList, - ListOfKJTList, ) - from torchrec.distributed.sharding.rw_sequence_sharding import ( RwSequenceEmbeddingDist, RwSequenceEmbeddingSharding, ) from torchrec.distributed.sharding.rw_sharding import ( BaseRwEmbeddingSharding, - InferRwSparseFeaturesDist, RwSparseFeaturesDist, ) -from torchrec.distributed.sharding.sequence_sharding import ( - InferSequenceShardingContext, - SequenceShardingContext, -) +from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext from torchrec.distributed.types import ( Awaitable, LazyAwaitable, @@ -55,49 +49,12 @@ ShardedTensor, ShardingEnv, ShardingType, + ShardMetadata, ) from torchrec.distributed.utils import append_prefix from torchrec.modules.mc_modules import ManagedCollisionCollection from torchrec.modules.utils import construct_jagged_tensors from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor -from torchrec.streamable import Multistreamable - - -@dataclass -class EmbeddingCollectionContext(Multistreamable): - sharding_contexts: List[InferSequenceShardingContext | SequenceShardingContext] - - def record_stream(self, stream: torch.Stream) -> None: - for ctx in self.sharding_contexts: - ctx.record_stream(stream) - - -class ManagedCollisionCollectionContext(EmbeddingCollectionContext): - pass - - -@torch.fx.wrap -def _fx_global_to_local_index( - feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] -) -> Dict[str, JaggedTensor]: - for feature, jt in feature_dict.items(): - jt._values = jt.values() - feature_to_offset[feature] - return feature_dict - - -@torch.fx.wrap -def _fx_jt_dict_add_offset( - feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] -) -> Dict[str, JaggedTensor]: - for feature, jt in feature_dict.items(): - jt._values = jt.values() + feature_to_offset[feature] - return feature_dict - - -@torch.fx.wrap -def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor: - return torch.tensor(kjt.length_per_key()) - logger: logging.Logger = logging.getLogger(__name__) @@ -149,6 +106,10 @@ def _wait_impl(self) -> KeyedJaggedTensor: return KeyedJaggedTensor.from_jt_dict(jt_dict) +class ManagedCollisionCollectionContext(EmbeddingCollectionContext): + pass + + def create_mc_sharding( sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], @@ -366,7 +327,7 @@ def _create_managed_collision_modules( torch.zeros(1, dtype=torch.int64, device=self._device) for _ in range(self._env.world_size) ] - if self.training and self._env.world_size > 1: + if self._env.world_size > 1: dist.all_gather( zch_size_by_rank, torch.tensor( @@ -573,8 +534,8 @@ def _dedup_indices( values=unique_indices, ) - ctx.input_features.append(kjt) # pyre-ignore - ctx.reverse_indices.append(reverse_indices) # pyre-ignore + ctx.input_features.append(kjt) + ctx.reverse_indices.append(reverse_indices) features_by_sharding.append(dedup_features) return features_by_sharding @@ -694,7 +655,6 @@ def compute( self._sharding_per_table_feature_splits, self._sharding_features, ): - assert isinstance(sharding_ctx, SequenceShardingContext) sharding_ctx.lengths_after_input_dist = features.lengths().view( -1, features.stride() ) @@ -797,6 +757,7 @@ def output_dist( embedding_names_per_sharding=self._embedding_names_per_sharding, need_indices=False, features_to_permute_indices=None, + reverse_indices=ctx.reverse_indices if self._use_index_dedup else None, ) def create_context(self) -> ManagedCollisionCollectionContext: @@ -872,462 +833,3 @@ def sharding_types(self, compute_device_type: str) -> List[str]: ShardingType.ROW_WISE.value, ] return types - - -@torch.fx.wrap -def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor: - return torch.cat([jt.values() for jt in jd.values()]) - - -@torch.fx.wrap -def update_jagged_tensor_dict( - output: Dict[str, JaggedTensor], new_dict: Dict[str, JaggedTensor] -) -> Dict[str, JaggedTensor]: - output.update(new_dict) - return output - - -class ShardedMCCRemapper(nn.Module): - def __init__( - self, - table_feature_splits: List[int], - fns: List[str], - managed_collision_modules: nn.ModuleDict, - shard_metadata: Dict[str, List[int]], - ) -> None: - super().__init__() - self._table_feature_splits: List[int] = table_feature_splits - self._fns: List[str] = fns - self.zchs = managed_collision_modules - logger.info(f"registered zchs: {self.zchs=}") - - # shard_size, shard_offset - self._shard_metadata: Dict[str, List[int]] = shard_metadata - self._table_to_offset: Dict[str, int] = { - table: offset[0] for table, offset in shard_metadata.items() - } - - def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: - # features per shard split by tables - feature_splits = features.split(self._table_feature_splits) - output: Dict[str, JaggedTensor] = {} - for i, (table, mc_module) in enumerate(self.zchs.items()): - kjt: KeyedJaggedTensor = feature_splits[i] - mc_input: Dict[str, JaggedTensor] = { - table: JaggedTensor( - values=kjt.values(), - lengths=kjt.lengths(), - weights=_get_length_per_key(kjt), - ) - } - remapped_input = mc_module(mc_input) - mc_input = self.global_to_local_index(remapped_input) - output[table] = remapped_input[table] - - values: torch.Tensor = _cat_jagged_values(output) - return KeyedJaggedTensor( - keys=self._fns, - values=values, - lengths=features.lengths(), - # original weights instead of features splits - weights=features.weights_or_none(), - ) - - def global_to_local_index( - self, - jt_dict: Dict[str, JaggedTensor], - ) -> Dict[str, JaggedTensor]: - return _fx_global_to_local_index(jt_dict, self._table_to_offset) - - -class ShardedQuantManagedCollisionCollection( - ShardedModule[ - KJTList, - KJTList, - KeyedJaggedTensor, - ManagedCollisionCollectionContext, - ] -): - def __init__( - self, - module: ManagedCollisionCollection, - table_name_to_parameter_sharding: Dict[str, ParameterSharding], - env: Union[ShardingEnv, Dict[str, ShardingEnv]], - device: torch.device, - embedding_shardings: List[ - EmbeddingSharding[ - EmbeddingShardingContext, - KeyedJaggedTensor, - torch.Tensor, - torch.Tensor, - ] - ], - qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, - ) -> None: - super().__init__() - self._env: ShardingEnv = ( - env - if not isinstance(env, Dict) - else embedding_shardings[0]._env # pyre-ignore[16] - ) - self._device = device - self.need_preprocess: bool = module.need_preprocess - self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( - copy.deepcopy(table_name_to_parameter_sharding) - ) - # TODO: create a MCSharding type instead of leveraging EmbeddingSharding - self._embedding_shardings = embedding_shardings - - self._embedding_names_per_sharding: List[List[str]] = [] - for sharding in self._embedding_shardings: - # TODO: support TWRW sharding - assert isinstance( - sharding, BaseRwEmbeddingSharding - ), "Only ROW_WISE sharding is supported." - self._embedding_names_per_sharding.append(sharding.embedding_names()) - - self._feature_to_table: Dict[str, str] = module._feature_to_table - self._table_to_features: Dict[str, List[str]] = module._table_to_features - self._has_uninitialized_input_dists: bool = True - self._input_dists: torch.nn.ModuleList = torch.nn.ModuleList([]) - self._managed_collision_modules: nn.ModuleDict = nn.ModuleDict() - self._create_managed_collision_modules(module) - self._features_order: List[int] = [] - - def _create_managed_collision_modules( - self, module: ManagedCollisionCollection - ) -> None: - - self._managed_collision_modules_per_rank: List[torch.nn.ModuleDict] = [ - torch.nn.ModuleDict() for _ in range(self._env.world_size) - ] - self._shard_metadata_per_rank: List[Dict[str, List[int]]] = [ - defaultdict() for _ in range(self._env.world_size) - ] - self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict() - # To map mch output indices from local to global. key: table_name - self._table_to_offset: Dict[str, int] = {} - - # the split sizes of tables belonging to each sharding. outer len is # shardings - self._sharding_per_table_feature_splits: List[List[int]] = [] - self._input_size_per_table_feature_splits: List[List[int]] = [] - # the split sizes of features per sharding. len is # shardings - self._sharding_feature_splits: List[int] = [] - # the split sizes of features per table. len is # tables sum over all shardings - self._table_feature_splits: List[int] = [] - self._feature_names: List[str] = [] - - # table names of each sharding - self._sharding_tables: List[List[str]] = [] - self._sharding_features: List[List[str]] = [] - - logger.info(f"_create_managed_collision_modules {self._embedding_shardings=}") - - for sharding in self._embedding_shardings: - assert isinstance(sharding, BaseRwEmbeddingSharding) - self._sharding_tables.append([]) - self._sharding_features.append([]) - self._sharding_per_table_feature_splits.append([]) - self._input_size_per_table_feature_splits.append([]) - - grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( - sharding._grouped_embedding_configs - ) - self._sharding_feature_splits.append(len(sharding.feature_names())) - - num_sharding_features = 0 - for group_config in grouped_embedding_configs: - for table in group_config.embedding_tables: - # pyre-ignore - global_meta_data = table.global_metadata.shards_metadata - output_segments = [ - x.shard_offsets[0] - for x in table.global_metadata.shards_metadata - ] + [table.num_embeddings] - mc_module = module._managed_collision_modules[table.name] - mc_module._is_inference = True - self._managed_collision_modules[table.name] = mc_module - self._sharding_tables[-1].append(table.name) - self._sharding_features[-1].extend(table.feature_names) - self._feature_names.extend(table.feature_names) - logger.info( - f"global_meta_data for table {table} is {global_meta_data}" - ) - - for i in range(self._env.world_size): - new_min_output_id = global_meta_data[i].shard_offsets[0] - new_range_size = global_meta_data[i].shard_sizes[0] - self._managed_collision_modules_per_rank[i][table.name] = ( - mc_module.rebuild_with_output_id_range( - output_id_range=( - new_min_output_id, - new_min_output_id + new_range_size, - ), - output_segments=output_segments, - device=( - torch.device("cpu") - if self._device.type == "cpu" - else torch.device(f"{self._device.type}:{i}") - ), - ) - ) - - self._managed_collision_modules_per_rank[i][ - table.name - ].training = False - self._shard_metadata_per_rank[i][table.name] = [ - new_min_output_id, - new_range_size, - ] - - input_size = self._managed_collision_modules[ - table.name - ].input_size() - - self._table_feature_splits.append(len(table.feature_names)) - self._sharding_per_table_feature_splits[-1].append( - self._table_feature_splits[-1] - ) - self._input_size_per_table_feature_splits[-1].append( - input_size, - ) - num_sharding_features += self._table_feature_splits[-1] - - assert num_sharding_features == len( - sharding.feature_names() - ), f"Shared feature is not supported. {num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}" - - if self._sharding_features[-1] != sharding.feature_names(): - logger.warn( - "The order of tables of this sharding is altered due to grouping: " - f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}" - ) - - logger.info(f"{self._table_feature_splits=}") - logger.info(f"{self._sharding_per_table_feature_splits=}") - logger.info(f"{self._input_size_per_table_feature_splits=}") - logger.info(f"{self._feature_names=}") - # logger.info(f"{self._table_to_offset=}") - logger.info(f"{self._sharding_tables=}") - logger.info(f"{self._sharding_features=}") - logger.info(f"{self._managed_collision_modules_per_rank=}") - logger.info(f"{self._shard_metadata_per_rank=}") - - def _create_input_dists( - self, - input_feature_names: List[str], - feature_device: Optional[torch.device] = None, - ) -> None: - assert ( - not is_fx_tracing() - ), "FX Tracing is invalid in _create_input_dists. Initialize the module with sample inputs before tracing." - feature_names: List[str] = [] - for sharding in self._embedding_shardings: - assert isinstance(sharding, BaseRwEmbeddingSharding) - - emb_sharding = [] - sharding_features = [] - for embedding_table_group in sharding._grouped_embedding_configs_per_rank[ - 0 - ]: - for table in embedding_table_group.embedding_tables: - shard_split_offsets = [ - shard.shard_offsets[0] - # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`. - for shard in table.global_metadata.shards_metadata - ] - # pyre-fixme[16]: Optional has no attribute size. - shard_split_offsets.append(table.global_metadata.size[0]) - emb_sharding.extend( - [shard_split_offsets] * len(table.embedding_names) - ) - sharding_features.extend(table.feature_names) - - feature_num_buckets: List[int] = [ - self._managed_collision_modules[self._feature_to_table[f]].buckets() - for f in sharding_features - ] - - input_sizes: List[int] = [ - self._managed_collision_modules[self._feature_to_table[f]].input_size() - for f in sharding_features - ] - - feature_hash_sizes: List[int] = [] - feature_total_num_buckets: List[int] = [] - for input_size, num_buckets in zip( - input_sizes, - feature_num_buckets, - ): - feature_hash_sizes.append(input_size) - feature_total_num_buckets.append(num_buckets) - - input_dist = InferRwSparseFeaturesDist( - world_size=sharding._world_size, - num_features=sharding._get_num_features(), - feature_hash_sizes=feature_hash_sizes, - feature_total_num_buckets=feature_total_num_buckets, - device=self._device, - is_sequence=True, - has_feature_processor=sharding._has_feature_processor, - need_pos=False, - embedding_shard_metadata=emb_sharding, - ) - self._input_dists.append(input_dist) - - feature_names.extend(sharding_features) - - for f in feature_names: - self._features_order.append(input_feature_names.index(f)) - self._features_order = ( - [] - if self._features_order == list(range(len(input_feature_names))) - else self._features_order - ) - self.register_buffer( - "_features_order_tensor", - torch.tensor( - self._features_order, device=feature_device, dtype=torch.int32 - ), - persistent=False, - ) - - # pyre-ignore - def input_dist( - self, - ctx: ManagedCollisionCollectionContext, - features: KeyedJaggedTensor, - ) -> ListOfKJTList: - if self._has_uninitialized_input_dists: - self._create_input_dists( - input_feature_names=features.keys(), feature_device=features.device() - ) - self._has_uninitialized_input_dists = False - - with torch.no_grad(): - if self._features_order: - features = features.permute( - self._features_order, - self._features_order_tensor, # pyre-ignore - ) - - feature_splits: List[KeyedJaggedTensor] = [] - if self.need_preprocess: - # NOTE: No shared features allowed! - assert ( - len(self._sharding_feature_splits) == 1 - ), "Preprocing only support single sharding type (row-wise)" - table_splits = features.split(self._table_feature_splits) - ti: int = 0 - for i, tables in enumerate(self._sharding_tables): - output: Dict[str, JaggedTensor] = {} - for table in tables: - kjt: KeyedJaggedTensor = table_splits[ti] - mc_module = self._managed_collision_modules[table] - # TODO: change to Dict[str, Tensor] - mc_input: Dict[str, JaggedTensor] = { - table: JaggedTensor( - values=kjt.values(), - lengths=kjt.lengths(), - ) - } - mc_input = mc_module.preprocess(mc_input) - output.update(mc_input) - ti += 1 - shard_kjt = KeyedJaggedTensor( - keys=self._sharding_features[i], - values=torch.cat([jt.values() for jt in output.values()]), - lengths=torch.cat([jt.lengths() for jt in output.values()]), - ) - feature_splits.append(shard_kjt) - else: - feature_splits = features.split(self._sharding_feature_splits) - - input_dist_result_list = [] - for feature_split, input_dist in zip(feature_splits, self._input_dists): - out = input_dist(feature_split) - input_dist_result_list.append(out.features) - ctx.sharding_contexts.append( - InferSequenceShardingContext( - features=out.features, - features_before_input_dist=features, - unbucketize_permute_tensor=( - out.unbucketize_permute_tensor - if isinstance(input_dist, InferRwSparseFeaturesDist) - else None - ), - bucket_mapping_tensor=out.bucket_mapping_tensor, - bucketized_length=out.bucketized_length, - ) - ) - - return ListOfKJTList(input_dist_result_list) - - def create_mcc_remappers(self) -> List[List[ShardedMCCRemapper]]: - ret: List[List[ShardedMCCRemapper]] = [] - # per shard - for table_feature_splits, fns in zip( - self._sharding_per_table_feature_splits, - self._sharding_features, - ): - sharding_ret: List[ShardedMCCRemapper] = [] - for i, mcms in enumerate(self._managed_collision_modules_per_rank): - sharding_ret.append( - ShardedMCCRemapper( - table_feature_splits=table_feature_splits, - fns=fns, - managed_collision_modules=mcms, - shard_metadata=self._shard_metadata_per_rank[i], - ) - ) - ret.append(sharding_ret) - return ret - - def compute( - self, - ctx: ManagedCollisionCollectionContext, - rank: int, - dist_input: KJTList, - ) -> KJTList: - raise NotImplementedError() - - # pyre-ignore - def output_dist( - self, - ctx: ManagedCollisionCollectionContext, - output: KJTList, - ) -> KeyedJaggedTensor: - raise NotImplementedError() - - def create_context(self) -> ManagedCollisionCollectionContext: - return ManagedCollisionCollectionContext(sharding_contexts=[]) - - -class InferManagedCollisionCollectionSharder(ManagedCollisionCollectionSharder): - # pyre-ignore - def shard( - self, - module: ManagedCollisionCollection, - params: Dict[str, ParameterSharding], - env: Union[ShardingEnv, Dict[str, ShardingEnv]], - embedding_shardings: List[ - EmbeddingSharding[ - EmbeddingShardingContext, - KeyedJaggedTensor, - torch.Tensor, - torch.Tensor, - ] - ], - device: Optional[torch.device] = None, - ) -> ShardedQuantManagedCollisionCollection: - - if device is None: - device = torch.device("cpu") - - return ShardedQuantManagedCollisionCollection( - module, - params, - env=env, - device=device, - embedding_shardings=embedding_shardings, - ) diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 5096ada6e..2077297b7 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -8,22 +8,9 @@ # pyre-strict -import logging from collections import defaultdict, deque from dataclasses import dataclass -from typing import ( - Any, - cast, - Dict, - Iterator, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Type, Union import torch from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( @@ -38,7 +25,6 @@ from torchrec.distributed.embedding_sharding import EmbeddingSharding from torchrec.distributed.embedding_types import ( BaseQuantEmbeddingSharder, - EmbeddingComputeKernel, FeatureShardingMixIn, GroupedEmbeddingConfig, InputDistOutputs, @@ -54,11 +40,6 @@ is_fused_param_register_tbe, ) from torchrec.distributed.global_settings import get_propogate_device -from torchrec.distributed.mc_modules import ( - InferManagedCollisionCollectionSharder, - ShardedMCCRemapper, - ShardedQuantManagedCollisionCollection, -) from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState from torchrec.distributed.sharding.cw_sequence_sharding import ( InferCwSequenceEmbeddingSharding, @@ -66,15 +47,11 @@ from torchrec.distributed.sharding.rw_sequence_sharding import ( InferRwSequenceEmbeddingSharding, ) -from torchrec.distributed.sharding.sequence_sharding import ( - InferSequenceShardingContext, - SequenceShardingContext, -) +from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext from torchrec.distributed.sharding.tw_sequence_sharding import ( InferTwSequenceEmbeddingSharding, ) from torchrec.distributed.types import ParameterSharding, ShardingEnv, ShardMetadata -from torchrec.distributed.utils import append_prefix from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, dtype_to_data_type, @@ -87,9 +64,8 @@ from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, - QuantManagedCollisionEmbeddingCollection, ) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from torchrec.streamable import Multistreamable torch.fx.wrap("len") @@ -103,12 +79,6 @@ pass -logger: logging.Logger = logging.getLogger(__name__) - - -ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) - - @dataclass class EmbeddingCollectionContext(Multistreamable): sharding_contexts: List[InferSequenceShardingContext] @@ -118,35 +88,6 @@ def record_stream(self, stream: torch.Stream) -> None: ctx.record_stream(stream) -class ManagedCollisionEmbeddingCollectionContext(EmbeddingCollectionContext): - - def __init__( - self, - sharding_contexts: Optional[List[SequenceShardingContext]] = None, - input_features: Optional[List[KeyedJaggedTensor]] = None, - reverse_indices: Optional[List[torch.Tensor]] = None, - evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None, - remapped_kjt: Optional[KJTList] = None, - ) -> None: - # pyre-ignore - super().__init__(sharding_contexts) - self.evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = ( - evictions_per_table - ) - self.remapped_kjt: Optional[KJTList] = remapped_kjt - - def record_stream(self, stream: torch.Stream) -> None: - super().record_stream(stream) - if self.evictions_per_table: - # pyre-ignore - for value in self.evictions_per_table.values(): - if value is None: - continue - value.record_stream(stream) - if self.remapped_kjt is not None: - self.remapped_kjt.record_stream(stream) - - def get_device_from_parameter_sharding( ps: ParameterSharding, ) -> Union[str, Tuple[str, ...]]: @@ -1148,240 +1089,3 @@ def forward(self, features: KeyedJaggedTensor) -> Tuple[ bucket_mapping_tensor, bucketized_lengths, ) - - -class ShardedMCECLookup(torch.nn.Module): - """ - This module implements distributed compute of a ShardedQuantManagedCollisionEmbeddingCollection. - - Args: - managed_collision_collection (ShardedQuantManagedCollisionCollection): managed collision collection - lookups (List[nn.Module]): embedding lookups - - Example:: - - """ - - def __init__( - self, - sharding: int, - rank: int, - mcc_remapper: ShardedMCCRemapper, - ec_lookup: nn.Module, - ) -> None: - super().__init__() - self._sharding = sharding - self._rank = rank - self._mcc_remapper = mcc_remapper - self._ec_lookup = ec_lookup - - def forward( - self, - features: KeyedJaggedTensor, - ) -> torch.Tensor: - remapped_kjt = self._mcc_remapper(features) - return self._ec_lookup(remapped_kjt) - - -class ShardedQuantManagedCollisionEmbeddingCollection(ShardedQuantEmbeddingCollection): - def __init__( - self, - module: QuantManagedCollisionEmbeddingCollection, - table_name_to_parameter_sharding: Dict[str, ParameterSharding], - mc_sharder: InferManagedCollisionCollectionSharder, - # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency - env: Union[ShardingEnv, Dict[str, ShardingEnv]], - fused_params: Optional[Dict[str, Any]] = None, - device: Optional[torch.device] = None, - ) -> None: - super().__init__( - module, table_name_to_parameter_sharding, env, fused_params, device - ) - - self._device = device - self._env = env - - # TODO: This is a hack since _embedding_module doesn't need input - # dist, so eliminating it so all fused a2a will ignore it. - # we're using ec input_dist directly, so this cannot be escaped. - # self._has_uninitialized_input_dist = False - embedding_shardings = list( - self._sharding_type_device_group_to_sharding.values() - ) - - self._managed_collision_collection: ShardedQuantManagedCollisionCollection = ( - mc_sharder.shard( - module._managed_collision_collection, - table_name_to_parameter_sharding, - env=env, - device=device, - # pyre-ignore - embedding_shardings=embedding_shardings, - ) - ) - self._return_remapped_features: bool = module._return_remapped_features - self._create_mcec_lookups() - - def _create_mcec_lookups(self) -> None: - mcec_lookups: List[nn.ModuleList] = [] - mcc_remappers: List[List[ShardedMCCRemapper]] = ( - self._managed_collision_collection.create_mcc_remappers() - ) - for sharding in range( - len(self._managed_collision_collection._embedding_shardings) - ): - ec_sharding_lookups = self._lookups[sharding] - sharding_mcec_lookups: List[ShardedMCECLookup] = [] - for j, ec_lookup in enumerate( - ec_sharding_lookups._embedding_lookups_per_rank # pyre-ignore - ): - sharding_mcec_lookups.append( - ShardedMCECLookup( - sharding, - j, - mcc_remappers[sharding][j], - ec_lookup, - ) - ) - mcec_lookups.append(nn.ModuleList(sharding_mcec_lookups)) - self._mcec_lookup: nn.ModuleList = nn.ModuleList(mcec_lookups) - - # For consistency with ShardedManagedCollisionEmbeddingCollection - @property - def _embedding_collection(self) -> ShardedQuantEmbeddingCollection: - return cast(ShardedQuantEmbeddingCollection, self) - - def input_dist( - self, - ctx: EmbeddingCollectionContext, - features: KeyedJaggedTensor, - ) -> ListOfKJTList: - # TODO: resolve incompatiblity with different contexts - if self._has_uninitialized_output_dist: - self._create_output_dist(features.device()) - self._has_uninitialized_output_dist = False - - return self._managed_collision_collection.input_dist( - # pyre-fixme [6] - ctx, - features, - ) - - def compute( - self, - ctx: ShrdCtx, - dist_input: ListOfKJTList, - ) -> List[List[torch.Tensor]]: - ret: List[List[torch.Tensor]] = [] - for i in range(len(self._managed_collision_collection._embedding_shardings)): - dist_input_i = dist_input[i] - lookups = self._mcec_lookup[i] - sharding_ret: List[torch.Tensor] = [] - for j, lookup in enumerate(lookups): - rank_ret = lookup( - features=dist_input_i[j], - ) - sharding_ret.append(rank_ret) - ret.append(sharding_ret) - return ret - - # pyre-ignore - def output_dist( - self, - ctx: ShrdCtx, - output: List[List[torch.Tensor]], - ) -> Tuple[ - Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] - ]: - - # pyre-ignore [6] - ebc_out = super().output_dist(ctx, output) - - kjt_out: Optional[KeyedJaggedTensor] = None - - return ebc_out, kjt_out - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for fqn, _ in self.named_parameters(): - yield append_prefix(prefix, fqn) - for fqn, _ in self.named_buffers(): - yield append_prefix(prefix, fqn) - - -class QuantManagedCollisionEmbeddingCollectionSharder( - BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection] -): - """ - This implementation uses non-fused EmbeddingCollection - """ - - def __init__( - self, - e_sharder: QuantEmbeddingCollectionSharder, - mc_sharder: InferManagedCollisionCollectionSharder, - ) -> None: - super().__init__() - self._e_sharder: QuantEmbeddingCollectionSharder = e_sharder - self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder - - def shardable_parameters( - self, module: QuantManagedCollisionEmbeddingCollection - ) -> Dict[str, torch.nn.Parameter]: - return self._e_sharder.shardable_parameters(module) - - def compute_kernels( - self, - sharding_type: str, - compute_device_type: str, - ) -> List[str]: - return [ - EmbeddingComputeKernel.QUANT.value, - ] - - def sharding_types(self, compute_device_type: str) -> List[str]: - return list( - set.intersection( - set(self._e_sharder.sharding_types(compute_device_type)), - set(self._mc_sharder.sharding_types(compute_device_type)), - ) - ) - - @property - def fused_params(self) -> Optional[Dict[str, Any]]: - # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints - return self._e_sharder.fused_params - - def shard( - self, - module: QuantManagedCollisionEmbeddingCollection, - params: Dict[str, ParameterSharding], - env: Union[ShardingEnv, Dict[str, ShardingEnv]], - device: Optional[torch.device] = None, - module_fqn: Optional[str] = None, - ) -> ShardedQuantManagedCollisionEmbeddingCollection: - fused_params = self.fused_params if self.fused_params else {} - fused_params["output_dtype"] = data_type_to_sparse_type( - dtype_to_data_type(module.output_dtype()) - ) - if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: - fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( - module, - MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, - False, - ) - if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: - fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( - module, FUSED_PARAM_REGISTER_TBE_BOOL, False - ) - return ShardedQuantManagedCollisionEmbeddingCollection( - module, - params, - self._mc_sharder, - env, - fused_params, - device, - ) - - @property - def module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]: - return QuantManagedCollisionEmbeddingCollection diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 179a1a40e..6cd4e15d6 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -8,8 +8,6 @@ # pyre-strict import copy -import logging -from collections import defaultdict from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union @@ -31,7 +29,6 @@ ShardedEmbeddingModule, ) from torchrec.distributed.types import ParameterSharding, ShardingType -from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule from torchrec.modules.embedding_configs import DataType from torchrec.streamable import Multistreamable from torchrec.tensor_types import UInt2Tensor, UInt4Tensor @@ -41,8 +38,6 @@ DistOut = TypeVar("DistOut") ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) -logger: logging.Logger = logging.getLogger(__name__) - def _append_table_shard( d: Dict[str, List[Shard]], table_name: str, shard: Shard @@ -314,66 +309,59 @@ def _load_from_state_dict( _unexpected_keys: List[str] = list(state_dict.keys()) for name, dst_tensor in dst_state_dict.items(): src_state_dict_name = prefix + name - try: - if src_state_dict_name not in state_dict: - _missing_keys.append(src_state_dict_name) - continue + if src_state_dict_name not in state_dict: + _missing_keys.append(src_state_dict_name) + continue - src_tensor = state_dict[src_state_dict_name] - if isinstance(dst_tensor, ShardedTensorBase) and isinstance( - src_tensor, ShardedTensorBase - ): - # sharded to sharded model, only identically sharded - for dst_local_shard in dst_tensor.local_shards(): - copied: bool = False - for src_local_shard in src_tensor.local_shards(): - if ( - dst_local_shard.metadata.shard_offsets - == src_local_shard.metadata.shard_offsets - and dst_local_shard.metadata.shard_sizes - == src_local_shard.metadata.shard_sizes - ): - dst_local_shard.tensor.copy_(src_local_shard.tensor) - copied = True - break - assert copied, "Incompatible state_dict" - elif isinstance(dst_tensor, ShardedTensorBase) and isinstance( - src_tensor, torch.Tensor - ): - # non_sharded to sharded model - for dst_local_shard in dst_tensor.local_shards(): - dst_tensor = dst_local_shard.tensor - assert src_tensor.ndim == dst_tensor.ndim - meta = dst_local_shard.metadata - t = src_tensor.detach() - rows_from = meta.shard_offsets[0] - rows_to = rows_from + meta.shard_sizes[0] - if t.ndim == 1: - dst_tensor.copy_(t[rows_from:rows_to]) - elif t.ndim == 2: - cols_from = meta.shard_offsets[1] - cols_to = cols_from + meta.shard_sizes[1] - dst_tensor.copy_( - t[ - rows_from:rows_to, - cols_from:cols_to, - ] - ) - else: - raise RuntimeError( - "Tensors with ndim > 2 are not supported" - ) - elif isinstance(dst_tensor, list) and isinstance( - src_tensor, torch.Tensor - ): - # non_sharded to CW columns qscale, qbias (one to many) - for t in dst_tensor: - assert isinstance(t, torch.Tensor) - t.copy_(src_tensor) - else: - dst_tensor.copy_(src_tensor) - except Exception as e: - logger.error(f"Weight {name} could not be loaded. Exception: {e}.") + src_tensor = state_dict[src_state_dict_name] + if isinstance(dst_tensor, ShardedTensorBase) and isinstance( + src_tensor, ShardedTensorBase + ): + # sharded to sharded model, only identically sharded + for dst_local_shard in dst_tensor.local_shards(): + copied: bool = False + for src_local_shard in src_tensor.local_shards(): + if ( + dst_local_shard.metadata.shard_offsets + == src_local_shard.metadata.shard_offsets + and dst_local_shard.metadata.shard_sizes + == src_local_shard.metadata.shard_sizes + ): + dst_local_shard.tensor.copy_(src_local_shard.tensor) + copied = True + break + assert copied, "Incompatible state_dict" + elif isinstance(dst_tensor, ShardedTensorBase) and isinstance( + src_tensor, torch.Tensor + ): + # non_sharded to sharded model + for dst_local_shard in dst_tensor.local_shards(): + dst_tensor = dst_local_shard.tensor + assert src_tensor.ndim == dst_tensor.ndim + meta = dst_local_shard.metadata + t = src_tensor.detach() + rows_from = meta.shard_offsets[0] + rows_to = rows_from + meta.shard_sizes[0] + if t.ndim == 1: + dst_tensor.copy_(t[rows_from:rows_to]) + elif t.ndim == 2: + cols_from = meta.shard_offsets[1] + cols_to = cols_from + meta.shard_sizes[1] + dst_tensor.copy_( + t[ + rows_from:rows_to, + cols_from:cols_to, + ] + ) + else: + raise RuntimeError("Tensors with ndim > 2 are not supported") + elif isinstance(dst_tensor, list) and isinstance(src_tensor, torch.Tensor): + # non_sharded to CW columns qscale, qbias (one to many) + for t in dst_tensor: + assert isinstance(t, torch.Tensor) + t.copy_(src_tensor) + else: + dst_tensor.copy_(src_tensor) _unexpected_keys.remove(src_state_dict_name) missing_keys.extend(_missing_keys) @@ -415,17 +403,17 @@ def sharded_tbes_weights_spec( # } # In the format of ebc.tbes.i.j.table_k.weight, where i is the index of the TBE, j is the index of the embedding bag within TBE i, k is the index of the original table set in the ebc embedding_configs # e.g. ebc.tbes.1.1.table_1.weight, it represents second embedding bag within the second TBE. This part of weight is from a shard of table_1 + ret: Dict[str, WeightSpec] = {} for module_fqn, module in sharded_model.named_modules(): type_name: str = type(module).__name__ is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name - is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name - if is_sqebc or is_sqec or is_sqmcec: - assert ( - is_sqec + is_sqebc + is_sqmcec == 1 - ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true" + if is_sqebc or is_sqec: + assert not ( + is_sqebc and is_sqec + ), "Cannot be both ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection" tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig ] = module.tbes_configs() @@ -480,81 +468,3 @@ def sharded_tbes_weights_spec( sharding_type=sharding_type, ) return ret - - -def sharded_zchs_buffers_spec( - sharded_model: torch.nn.Module, -) -> Dict[str, WeightSpec]: - """ - OUTPUT: - Example: - "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_lookup.zchs.viewer_rid_duplicate._hash_zch_identities", [0, 0], [500, 1]) - "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.1._mcc_lookup.zchs.viewer_rid_duplicate._hash_zch_identities", [500, 0], [1000, 1]) - - 'main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_lookup.zchs.viewer_rid_duplicate._hash_zch_identities': WeightSpec(fqn='main_module.module.ec_in_task_arch_hash._ d_embedding_collection._managed_collision_collection.viewer_rid_duplicate._hash_zch_identities' - """ - - def _get_table_names( - sharded_module: torch.nn.Module, - ) -> List[str]: - table_names: List[str] = [] - for _, module in sharded_module.named_modules(): - type_name: str = type(module).__name__ - if "ShardedMCCRemapper" in type_name: - for table_name in module._tables: - if table_name not in table_names: - table_names.append(table_name) - return table_names - - def _get_unsharded_fqn_identities( - sharded_module: torch.nn.Module, - fqn: str, - table_name: str, - ) -> str: - for module_fqn, module in sharded_module.named_modules(): - type_name: str = type(module).__name__ - if "ManagedCollisionCollection" in type_name: - if table_name in module._table_to_features: - return f"{fqn}.{module_fqn}.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}" - logger.info(f"did not find table {table_name} in module {fqn}") - return "" - - ret: Dict[str, WeightSpec] = defaultdict() - for module_fqn, module in sharded_model.named_modules(): - type_name: str = type(module).__name__ - if "ShardedQuantManagedCollisionEmbeddingCollection" in type_name: - sharding_type = ShardingType.ROW_WISE.value - table_name_to_unsharded_fqn_identities: Dict[str, str] = {} - for subfqn, submodule in module.named_modules(): - type_name: str = type(submodule).__name__ - if "ShardedMCCRemapper" in type_name: - for table_name in submodule.zchs.keys(): - # identities tensor has only one column - shard_offsets: List[int] = [ - submodule._shard_metadata[table_name][0], - 0, - ] - shard_sizes: List[int] = [ - submodule._shard_metadata[table_name][1], - 1, - ] - if table_name not in table_name_to_unsharded_fqn_identities: - table_name_to_unsharded_fqn_identities[table_name] = ( - _get_unsharded_fqn_identities( - module, module_fqn, table_name - ) - ) - unsharded_fqn_identities: str = ( - table_name_to_unsharded_fqn_identities[table_name] - ) - # subfqn contains the index of sharding, so no need to add it specifically here - sharded_fqn_identities: str = ( - f"{module_fqn}.{subfqn}.zchs.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}" - ) - ret[sharded_fqn_identities] = WeightSpec( - fqn=unsharded_fqn_identities, - shard_offsets=shard_offsets, - shard_sizes=shard_sizes, - sharding_type=sharding_type, - ) - return ret diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index deac8359b..0ecdabb7a 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -7,7 +7,6 @@ # pyre-strict -import logging import math from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar, Union @@ -59,7 +58,6 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable -logger: logging.Logger = logging.getLogger(__name__) C = TypeVar("C", bound=Multistreamable) F = TypeVar("F", bound=Multistreamable) @@ -576,39 +574,11 @@ def create_output_dist( ) -@torch.fx.wrap -def get_total_num_buckets_runtime_device( - total_num_buckets: Optional[List[int]], - runtime_device: torch.device, - tensor_cache: Dict[ - str, - Tuple[torch.Tensor, List[torch.Tensor]], - ], - dtype: torch.dtype = torch.int32, -) -> Optional[torch.Tensor]: - if total_num_buckets is None: - return None - cache_key: str = "__total_num_buckets" - if cache_key not in tensor_cache: - tensor_cache[cache_key] = ( - torch.tensor( - total_num_buckets, - device=runtime_device, - dtype=dtype, - ), - [], - ) - return tensor_cache[cache_key][0] - - @torch.fx.wrap def get_block_sizes_runtime_device( block_sizes: List[int], runtime_device: torch.device, - tensor_cache: Dict[ - str, - Tuple[torch.Tensor, List[torch.Tensor]], - ], + tensor_cache: Dict[str, Tuple[torch.Tensor, List[torch.Tensor]]], embedding_shard_metadata: Optional[List[List[int]]] = None, dtype: torch.dtype = torch.int32, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: @@ -643,7 +613,6 @@ def __init__( world_size: int, num_features: int, feature_hash_sizes: List[int], - feature_total_num_buckets: Optional[List[int]] = None, device: Optional[torch.device] = None, is_sequence: bool = False, has_feature_processor: bool = False, @@ -651,22 +620,12 @@ def __init__( embedding_shard_metadata: Optional[List[List[int]]] = None, ) -> None: super().__init__() - logger.info( - f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}" - ) self._world_size: int = world_size self._num_features = num_features - self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets - - self.feature_block_sizes: List[int] = [] - for i, hash_size in enumerate(feature_hash_sizes): - block_divisor = self._world_size - if feature_total_num_buckets is not None: - assert feature_total_num_buckets[i] % self._world_size == 0 - block_divisor = feature_total_num_buckets[i] - self.feature_block_sizes.append( - (hash_size + block_divisor - 1) // block_divisor - ) + self.feature_block_sizes: List[int] = [ + (hash_size + self._world_size - 1) // self._world_size + for hash_size in feature_hash_sizes + ] self.tensor_cache: Dict[ str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]] ] = {} @@ -692,12 +651,6 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: self._embedding_shard_metadata, sparse_features.values().dtype, ) - total_num_buckets = get_total_num_buckets_runtime_device( - self._feature_total_num_buckets, - sparse_features.device(), - self.tensor_cache, - sparse_features.values().dtype, - ) ( bucketized_features, @@ -707,7 +660,6 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: sparse_features, num_buckets=self._world_size, block_sizes=block_sizes, - total_num_buckets=total_num_buckets, bucketize_pos=( self._has_feature_processor if sparse_features.weights_or_none() is None diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 27b011300..a9e536015 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -27,12 +27,8 @@ from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, ) -from torchrec.distributed.mc_modules import InferManagedCollisionCollectionSharder from torchrec.distributed.planner.constants import MIN_CW_DIM -from torchrec.distributed.quant_embedding import ( - QuantEmbeddingCollectionSharder, - QuantManagedCollisionEmbeddingCollectionSharder, -) +from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, @@ -55,13 +51,6 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: cast(ModuleSharder[nn.Module], QuantEmbeddingCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingCollectionSharder()), - cast( - ModuleSharder[nn.Module], - QuantManagedCollisionEmbeddingCollectionSharder( - QuantEmbeddingCollectionSharder(), - InferManagedCollisionCollectionSharder(), - ), - ), ] @@ -845,7 +834,7 @@ def construct_module_sharding_plan( assert isinstance( module, sharder.module_type - ), f"Incorrect sharder {type(sharder)} for module type {type(module)}" + ), f"Incorrect sharder for module type {type(module)}" shardable_parameters = sharder.shardable_parameters(module) assert shardable_parameters.keys() == per_param_sharding.keys(), ( "per_param_sharding_config doesn't match the shardable parameters of the module," diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py index 60de369d1..20f883e19 100644 --- a/torchrec/distributed/tests/test_mc_embedding.py +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -529,9 +529,8 @@ def _test_sharding_dedup( # noqa C901 dedup_loss1.backward() assert torch.allclose(loss1, dedup_loss1) - # deduping is not being used right now - # assert torch.allclose(remapped_1.values(), dedup_remapped_1.values()) - # assert torch.allclose(remapped_1.lengths(), dedup_remapped_1.lengths()) + assert torch.allclose(remapped_1.values(), dedup_remapped_1.values()) + assert torch.allclose(remapped_1.lengths(), dedup_remapped_1.lengths()) @skip_if_asan_class diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index b36800d08..d5ba9e774 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -15,9 +15,6 @@ import torch from hypothesis import given, settings, Verbosity from torchrec import distributed as trec_dist -from torchrec.distributed.quant_embedding import ( - QuantManagedCollisionEmbeddingCollectionSharder, -) from torchrec.distributed.sharding_plan import ( column_wise, construct_module_sharding_plan, @@ -66,7 +63,6 @@ from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, - QuantManagedCollisionEmbeddingCollection, ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -896,24 +892,21 @@ def test_str(self) -> None: ) } ) - expected = """module: ebc + expected = """ +module: ebc - param | sharding type | compute kernel | ranks + param | sharding type | compute kernel | ranks -------- | ------------- | -------------- | ------ -user_id | table_wise | dense | [0] +user_id | table_wise | dense | [0] movie_id | row_wise | dense | [0, 1] - param | shard offsets | shard sizes | placement + param | shard offsets | shard sizes | placement -------- | ------------- | ----------- | ------------- user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 """ - self.maxDiff = None - for i in range(len(expected.splitlines())): - self.assertEqual( - expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() - ) + self.assertEqual(expected.strip(), str(plan)) def test_module_to_default_sharders(self) -> None: default_sharder_map = get_module_to_default_sharders() @@ -928,7 +921,6 @@ def test_module_to_default_sharders(self) -> None: QuantEmbeddingCollection, ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection, - QuantManagedCollisionEmbeddingCollection, ], ) self.assertIsInstance( @@ -962,8 +954,3 @@ def test_module_to_default_sharders(self) -> None: default_sharder_map[ManagedCollisionEmbeddingCollection], ManagedCollisionEmbeddingCollectionSharder, ) - - self.assertIsInstance( - default_sharder_map[QuantManagedCollisionEmbeddingCollection], - QuantManagedCollisionEmbeddingCollectionSharder, - )