diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 6f6c2f61e..d14db6f9a 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -127,6 +127,8 @@ def enumerate(
                     bounds_check_mode,
                 ) = _extract_constraints_for_param(self._constraints, name)
 
+                sharding_options_per_table: List[ShardingOption] = []
+
                 for sharding_type in self._filter_sharding_types(
                     name, sharder.sharding_types(self._compute_device)
                 ):
@@ -150,7 +152,7 @@ def enumerate(
                         elif isinstance(child_module, EmbeddingTowerCollection):
                             tower_index = _get_tower_index(name, child_module)
                             dependency = child_path + ".tower_" + str(tower_index)
-                        sharding_options.append(
+                        sharding_options_per_table.append(
                             ShardingOption(
                                 name=name,
                                 tensor=param,
@@ -172,12 +174,14 @@ def enumerate(
                                 is_pooled=is_pooled,
                             )
                         )
-                if not sharding_options:
+                if not sharding_options_per_table:
                     raise RuntimeError(
                         "No available sharding type and compute kernel combination "
                         f"after applying user provided constraints for {name}"
                     )
+
+                sharding_options.extend(sharding_options_per_table)
+
         self.populate_estimates(sharding_options)
 
         return sharding_options
diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py
index 9291aed3e..3736f8d62 100644
--- a/torchrec/distributed/planner/tests/test_enumerators.py
+++ b/torchrec/distributed/planner/tests/test_enumerators.py
@@ -858,3 +858,49 @@ def test_tower_collection_sharding(self) -> None:
     def test_empty(self) -> None:
         sharding_options = self.enumerator.enumerate(self.model, sharders=[])
         self.assertFalse(sharding_options)
+
+    def test_throw_ex_no_sharding_option_for_table(self) -> None:
+        cw_constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.COLUMN_WISE.value,
+            ],
+            compute_kernels=[
+                EmbeddingComputeKernel.FUSED.value,
+            ],
+        )
+
+        rw_constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.TABLE_ROW_WISE.value,
+            ],
+            compute_kernels=[
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ],
+        )
+
+        constraints = {
+            "table_0": cw_constraint,
+            "table_1": rw_constraint,
+            "table_2": cw_constraint,
+            "table_3": cw_constraint,
+        }
+
+        enumerator = EmbeddingEnumerator(
+            topology=Topology(
+                world_size=self.world_size,
+                compute_device=self.compute_device,
+                local_world_size=self.local_world_size,
+            ),
+            batch_size=self.batch_size,
+            constraints=constraints,
+        )
+
+        sharder = cast(ModuleSharder[torch.nn.Module], CWSharder())
+
+        with self.assertRaises(Exception) as context:
+            _ = enumerator.enumerate(self.model, [sharder])
+
+        self.assertTrue(
+            "No available sharding type and compute kernel combination after applying user provided constraints for table_1"
+            in str(context.exception)
+        )