diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py index 856f26e8e..b2e80abe1 100644 --- a/torchrec/distributed/planner/partitioners.py +++ b/torchrec/distributed/planner/partitioners.py @@ -9,7 +9,7 @@ import logging from dataclasses import dataclass from enum import Enum -from typing import cast, List +from typing import cast, Dict, List from torchrec.distributed.planner.perf_models import NoopPerfModel @@ -58,6 +58,7 @@ class ShardingOptionGroup: sharding_options: List[ShardingOption] storage_sum: Storage perf_sum: float + param_count: int class SortBy(Enum): @@ -66,8 +67,20 @@ class SortBy(Enum): def _group_and_sort_non_uniform_sharding_options( - sharding_options: List[ShardingOption], sort_by: SortBy = SortBy.STORAGE + sharding_options: List[ShardingOption], + sort_by: SortBy = SortBy.STORAGE, + balance_modules: bool = False, ) -> List[ShardingOptionGroup]: + + # count sharding options per module path + param_count: Dict[str, int] = {} + for sharding_option in sharding_options: + path = sharding_option.path + if path not in param_count: + param_count[path] = 0 + param_count[path] += 1 + logger.info(f"param_count is {param_count}") + sharding_option_groups_by_dependency = {} for sharding_option in sharding_options: if sharding_option.partition_by == PartitionByType.UNIFORM.value: @@ -79,6 +92,8 @@ def _group_and_sort_non_uniform_sharding_options( [sharding_option], sharding_option.total_storage, sharding_option.total_perf, + # negative value to indicate that smaller modules should be sorted first + param_count=-param_count[sharding_option.path], ) else: sharding_option_groups_by_dependency[group_key].sharding_options.append( @@ -90,25 +105,44 @@ def _group_and_sort_non_uniform_sharding_options( sharding_option_groups_by_dependency[ group_key ].perf_sum += sharding_option.total_perf + sharding_option_groups = list(sharding_option_groups_by_dependency.values()) + sort_by_attributes: List[str] = [] + if balance_modules: + 
sort_by_attributes.append("param_count") + if sort_by == SortBy.STORAGE: - sharding_option_groups.sort(key=lambda group: group.storage_sum, reverse=True) + sort_by_attributes.append("storage_sum") elif sort_by == SortBy.PERF: - sharding_option_groups.sort(key=lambda group: group.perf_sum, reverse=True) + sort_by_attributes.append("perf_sum") else: raise RuntimeError(f"Unexpected sort_by: {sort_by}") + sharding_option_groups.sort( + key=lambda group: [getattr(group, attr) for attr in sort_by_attributes], + reverse=True, + ) + return sharding_option_groups class GreedyPerfPartitioner(Partitioner): - """ - Greedy Partitioner + """Greedy Partitioner + + Args: + sort_by (SortBy): Sort sharding options by storage or perf in + descending order (i.e., large tables will be placed first). + balance_modules (bool): Whether to sort by modules first, where + smaller modules will be sorted first. In effect, this will place + tables in each module in a balanced way. """ - def __init__(self, sort_by: SortBy = SortBy.STORAGE) -> None: + def __init__( + self, sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False + ) -> None: self._sort_by = sort_by + self._balance_modules = balance_modules def partition( self, @@ -186,7 +220,7 @@ def partition( # group the rest sharding options by colocation type (co-host, co-device, none) # and sort the groups by storage in reverse order sharding_option_groups = _group_and_sort_non_uniform_sharding_options( - proposal, sort_by=self._sort_by + proposal, sort_by=self._sort_by, balance_modules=self._balance_modules ) for sharding_option_group in sharding_option_groups: @@ -335,13 +369,29 @@ def _uniform_partition( class MemoryBalancedPartitioner(Partitioner): - """ - Memory balanced Partitioner. + """Memory balanced Partitioner. + + Args: + max_search_count (int): Maximum number of times to call the + GreedyPerfPartitioner. + tolerance (float): The maximum acceptable difference between the + original plan and the new plan. 
If tolerance is 1, that means a new + plan will be rejected if its perf is 200% of the original plan + (i.e., the plan is 100% worse). + balance_modules (bool): Whether to sort by modules first, where + smaller modules will be sorted first. In effect, this will place + tables in each module in a balanced way. """ - def __init__(self, max_search_count: int = 10, tolerance: float = 0.02) -> None: + def __init__( + self, + max_search_count: int = 10, + tolerance: float = 0.02, + balance_modules: bool = False, + ) -> None: self._max_search_count: int = max_search_count self._tolerance: float = tolerance + self._balance_modules: bool = balance_modules def partition( self, @@ -354,7 +404,9 @@ def partition( of memory. """ _perf_model: PerfModel = NoopPerfModel(storage_constraint) - _partitioner = GreedyPerfPartitioner(sort_by=SortBy.PERF) + _partitioner = GreedyPerfPartitioner( + sort_by=SortBy.PERF, balance_modules=self._balance_modules + ) # copying storage_constraint, since we modify it in place _topology: Topology = copy.deepcopy(storage_constraint) diff --git a/torchrec/distributed/planner/tests/test_partitioners.py b/torchrec/distributed/planner/tests/test_partitioners.py index da91e22a6..e72ca2261 100644 --- a/torchrec/distributed/planner/tests/test_partitioners.py +++ b/torchrec/distributed/planner/tests/test_partitioners.py @@ -714,3 +714,133 @@ def test_different_sharding_plan(self) -> None: memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses)) + + +class TestBalanceModules(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology(world_size=2, compute_device=compute_device) + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=4 * (10 + i), + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(1) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=200 + i, + embedding_dim=8 
* (10 + i), + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(3) + ] + self.topology = Topology( + world_size=2, + compute_device=compute_device, + hbm_cap=2000 * 1024**2, + ) + self.model = TestSparseNN(tables=tables, weighted_tables=weighted_tables) + self.enumerator = EmbeddingEnumerator( + topology=self.topology, batch_size=BATCH_SIZE + ) + + self.sharding_options = self.enumerator.enumerate( + module=self.model, sharders=[TWSharder()] + ) + for sharding_option in self.sharding_options: + sharding_option.shards[0].perf = Perf( + fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10 + ) + sharding_option.shards[0].storage = Storage( + hbm=10 * 1024**2, ddr=1000 * 1024**2 + ) + + def test_greedy_partitioner(self) -> None: + greedy_partitioner = GreedyPerfPartitioner(balance_modules=False) + balance_modules_greedy_partitioner = GreedyPerfPartitioner(balance_modules=True) + + greedy_sharding_plan = greedy_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + greedy_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in greedy_sharding_plan + } + + reset_shard_rank(self.sharding_options) + + balance_modules_sharding_plan = balance_modules_greedy_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + balance_modules_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in balance_modules_sharding_plan + } + + greedy_expected_ranks = { + "weighted_table_0": [0], + "weighted_table_1": [1], + "weighted_table_2": [0], + "table_0": [1], + } + balance_modules_expected_ranks = { + "weighted_table_0": [1], + "weighted_table_1": [0], + "weighted_table_2": [1], + "table_0": [0], + } + + self.assertEqual(greedy_expected_ranks, greedy_ranks) + self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks) + + def 
test_memory_balanced_partitioner(self) -> None: + memory_balanced_partitioner = MemoryBalancedPartitioner( + tolerance=100, balance_modules=False + ) + balance_modules_memory_balanced_partitioner = MemoryBalancedPartitioner( + tolerance=100, balance_modules=True + ) + + memory_balanced_plan = memory_balanced_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + memory_balanced_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in memory_balanced_plan + } + + reset_shard_rank(self.sharding_options) + + balance_modules_sharding_plan = ( + balance_modules_memory_balanced_partitioner.partition( + proposal=self.sharding_options, + storage_constraint=self.topology, + ) + ) + balance_modules_ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in balance_modules_sharding_plan + } + + memory_balanced_expected_ranks = { + "weighted_table_0": [0], + "weighted_table_1": [1], + "weighted_table_2": [0], + "table_0": [1], + } + balance_modules_expected_ranks = { + "weighted_table_0": [1], + "weighted_table_1": [0], + "weighted_table_2": [1], + "table_0": [0], + } + + self.assertEqual(memory_balanced_expected_ranks, memory_balanced_ranks) + self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks)