diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index d14db6f9a..90a506e6c 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -125,6 +125,7 @@ def enumerate(
                     enforce_hbm,
                     stochastic_rounding,
                     bounds_check_mode,
+                    feature_names,
                 ) = _extract_constraints_for_param(self._constraints, name)

                 sharding_options_per_table: List[ShardingOption] = []
@@ -172,6 +173,7 @@ def enumerate(
                                 bounds_check_mode=bounds_check_mode,
                                 dependency=dependency,
                                 is_pooled=is_pooled,
+                                feature_names=feature_names,
                             )
                         )
             if not sharding_options_per_table:
@@ -265,6 +267,7 @@ def _extract_constraints_for_param(
     Optional[bool],
     Optional[bool],
     Optional[BoundsCheckMode],
+    Optional[List[str]],
 ]:
     input_lengths = [POOLING_FACTOR]
     col_wise_shard_dim = None
@@ -272,6 +275,7 @@
     enforce_hbm = None
     stochastic_rounding = None
     bounds_check_mode = None
+    feature_names = None

     if constraints and constraints.get(name):
         input_lengths = constraints[name].pooling_factors
@@ -280,6 +284,7 @@
         enforce_hbm = constraints[name].enforce_hbm
         stochastic_rounding = constraints[name].stochastic_rounding
         bounds_check_mode = constraints[name].bounds_check_mode
+        feature_names = constraints[name].feature_names

     return (
         input_lengths,
@@ -288,6 +293,7 @@
         enforce_hbm,
         stochastic_rounding,
         bounds_check_mode,
+        feature_names,
     )


diff --git a/torchrec/distributed/planner/tests/test_planners.py b/torchrec/distributed/planner/tests/test_planners.py
index a41678cfd..b8f52e118 100644
--- a/torchrec/distributed/planner/tests/test_planners.py
+++ b/torchrec/distributed/planner/tests/test_planners.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 import unittest
-from typing import cast, List
+from typing import cast, List, Optional

 import torch
 from torch import nn
@@ -14,7 +14,12 @@
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.planner import ParameterConstraints
 from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
-from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology
+from torchrec.distributed.planner.types import (
+    PlannerError,
+    PlannerErrorType,
+    ShardingOption,
+    Topology,
+)
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.types import (
@@ -174,18 +179,32 @@ def setUp(self) -> None:
         self.topology = Topology(
             world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device
         )
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=64,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(4)
+        ]
         self.constraints = {
             "table_0": ParameterConstraints(
                 enforce_hbm=True,
                 cache_params=CacheParams(
                     algorithm=CacheAlgorithm.LFU,
                 ),
+                feature_names=self.tables[0].feature_names,
             ),
             "table_1": ParameterConstraints(
                 enforce_hbm=False,
                 stochastic_rounding=True,
+                feature_names=self.tables[1].feature_names,
+            ),
+            "table_2": ParameterConstraints(
+                bounds_check_mode=BoundsCheckMode.FATAL,
+                feature_names=self.tables[2].feature_names,
             ),
-            "table_2": ParameterConstraints(bounds_check_mode=BoundsCheckMode.FATAL),
             "table_3": ParameterConstraints(
                 cache_params=CacheParams(
                     algorithm=CacheAlgorithm.LFU,
@@ -193,6 +212,7 @@ def setUp(self) -> None:
                     reserved_memory=1.0,
                     precision=DataType.FP16,
                 ),
+                feature_names=self.tables[3].feature_names,
             ),
         }
         self.planner = EmbeddingShardingPlanner(
@@ -200,16 +220,7 @@ def setUp(self) -> None:
         )

     def test_fused_paramters_from_constraints(self) -> None:
-        tables = [
-            EmbeddingBagConfig(
-                num_embeddings=100,
-                embedding_dim=64,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-            )
-            for i in range(4)
-        ]
-        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
+        model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta"))
         sharding_plan = self.planner.plan(module=model, sharders=get_default_sharders())

         expected_fused_params = {
@@ -253,3 +264,28 @@ def test_fused_paramters_from_constraints(self) -> None:
                 ),
                 expected_fused_params[table],
             )
+
+    def test_passing_info_through_constraints(self) -> None:
+        model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta"))
+        _ = self.planner.plan(module=model, sharders=get_default_sharders())
+
+        best_plan: Optional[List[ShardingOption]] = self.planner._best_plan
+        self.assertIsNotNone(best_plan)
+
+        for table, constraint, sharding_option in zip(
+            self.tables, self.constraints.values(), best_plan
+        ):
+            self.assertEqual(table.name, sharding_option.name)
+
+            self.assertEqual(table.feature_names, sharding_option.feature_names)
+            self.assertEqual(table.feature_names, constraint.feature_names)
+
+            self.assertEqual(constraint.cache_params, sharding_option.cache_params)
+            self.assertEqual(constraint.enforce_hbm, sharding_option.enforce_hbm)
+            self.assertEqual(
+                constraint.stochastic_rounding, sharding_option.stochastic_rounding
+            )
+            self.assertEqual(
+                constraint.bounds_check_mode, sharding_option.bounds_check_mode
+            )
+            self.assertEqual(constraint.is_weighted, sharding_option.is_weighted)
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 7b92ad2b3..f3fe32df5 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -288,6 +288,7 @@ def __init__(
         bounds_check_mode: Optional[BoundsCheckMode] = None,
         dependency: Optional[str] = None,
         is_pooled: Optional[bool] = None,
+        feature_names: Optional[List[str]] = None,
     ) -> None:
         self.name = name
         self._tensor = tensor
@@ -307,6 +308,7 @@ def __init__(
         self.dependency = dependency
         self._is_pooled = is_pooled
         self.is_weighted: Optional[bool] = None
+        self.feature_names: Optional[List[str]] = feature_names

     @property
     def tensor(self) -> torch.Tensor:
@@ -436,6 +438,7 @@ class ParameterConstraints:
     enforce_hbm: Optional[bool] = None
     stochastic_rounding: Optional[bool] = None
     bounds_check_mode: Optional[BoundsCheckMode] = None
+    feature_names: Optional[List[str]] = None


 class PlannerErrorType(Enum):
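
For context, a minimal sketch of how the new plumbing is used end to end: a `feature_names` list set on a `ParameterConstraints` entry is extracted by `_extract_constraints_for_param` and surfaces on the corresponding `ShardingOption` in the planner's best plan. The setup mirrors the new `test_passing_info_through_constraints` test above; the single-table config and `compute_device="cuda"` are illustrative assumptions, not part of the patch.

```python
import torch
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# One table whose feature names the planner should carry through.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=64,
        name="table_0",
        feature_names=["feature_0"],
    )
]
# feature_names on the constraint is the field added by this diff.
constraints = {"table_0": ParameterConstraints(feature_names=tables[0].feature_names)}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),  # assumed topology
    constraints=constraints,
)
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
planner.plan(module=model, sharders=get_default_sharders())

# The enumerator forwards feature_names into every candidate
# ShardingOption, so the chosen plan exposes it (via the internal
# _best_plan field the new test inspects).
for sharding_option in planner._best_plan or []:
    assert sharding_option.feature_names == ["feature_0"]
```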