From f55935db78d04e47a17a4fdbaaea116d55398ff2 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Mon, 13 Nov 2023 16:18:17 -0800
Subject: [PATCH] Add option to balance modules to partitioners

Summary:
This diff makes sure the module counts on GPUs are balanced. We do this by
sorting the tables from smaller modules first.

This should work fine with dependencies (aka towers), since a dependency
either is the child_path or starts with the child_path, which is the path we
count modules by.

Differential Revision: D51180772
---
 torchrec/distributed/planner/partitioners.py |  52 +++++--
 .../planner/tests/test_partitioners.py       | 130 ++++++++++++++++++
 2 files changed, 174 insertions(+), 8 deletions(-)

diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py
index 856f26e8e..136cf04fe 100644
--- a/torchrec/distributed/planner/partitioners.py
+++ b/torchrec/distributed/planner/partitioners.py
@@ -9,7 +9,7 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import cast, List
+from typing import cast, Dict, List
 
 from torchrec.distributed.planner.perf_models import NoopPerfModel
@@ -58,6 +58,7 @@ class ShardingOptionGroup:
     sharding_options: List[ShardingOption]
     storage_sum: Storage
     perf_sum: float
+    module_count: int
 
 
 class SortBy(Enum):
@@ -66,8 +67,20 @@ class SortBy(Enum):
 
 
 def _group_and_sort_non_uniform_sharding_options(
-    sharding_options: List[ShardingOption], sort_by: SortBy = SortBy.STORAGE
+    sharding_options: List[ShardingOption],
+    sort_by: SortBy = SortBy.STORAGE,
+    balance_modules: bool = False,
 ) -> List[ShardingOptionGroup]:
+
+    # count modules by name
+    module_count: Dict[str, int] = {}
+    for sharding_option in sharding_options:
+        path = sharding_option.path
+        if path not in module_count:
+            module_count[path] = 0
+        module_count[path] += 1
+    logger.info(f"module_count is {module_count}")
+
     sharding_option_groups_by_dependency = {}
     for sharding_option in sharding_options:
         if sharding_option.partition_by == PartitionByType.UNIFORM.value:
@@ -79,6 +92,8 @@ def _group_and_sort_non_uniform_sharding_options(
                 [sharding_option],
                 sharding_option.total_storage,
                 sharding_option.total_perf,
+                # negative value to indicate that smaller modules should be sorted first
+                module_count=-module_count[sharding_option.path],
             )
         else:
             sharding_option_groups_by_dependency[group_key].sharding_options.append(
@@ -90,15 +105,25 @@ def _group_and_sort_non_uniform_sharding_options(
                 sharding_option
             )
             sharding_option_groups_by_dependency[
                 group_key
             ].storage_sum += sharding_option.total_storage
             sharding_option_groups_by_dependency[
                 group_key
             ].perf_sum += sharding_option.total_perf
+
     sharding_option_groups = list(sharding_option_groups_by_dependency.values())
+
+    sort_by_attributes: List[str] = []
+    if balance_modules:
+        sort_by_attributes.append("module_count")
+
     if sort_by == SortBy.STORAGE:
-        sharding_option_groups.sort(key=lambda group: group.storage_sum, reverse=True)
+        sort_by_attributes.append("storage_sum")
     elif sort_by == SortBy.PERF:
-        sharding_option_groups.sort(key=lambda group: group.perf_sum, reverse=True)
+        sort_by_attributes.append("perf_sum")
     else:
         raise RuntimeError(f"Unexpected sort_by: {sort_by}")
 
+    sharding_option_groups.sort(
+        key=lambda group: [getattr(group, attr) for attr in sort_by_attributes],
+        reverse=True,
+    )
+
     return sharding_option_groups
@@ -107,8 +132,11 @@ class GreedyPerfPartitioner(Partitioner):
     """
     Greedy Partitioner
     """
-    def __init__(self, sort_by: SortBy = SortBy.STORAGE) -> None:
+    def __init__(
+        self, sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False
+    ) -> None:
         self._sort_by = sort_by
+        self._balance_modules = balance_modules
 
     def partition(
         self,
@@ -186,7 +214,7 @@ def partition(
         # group the rest sharding options by colocation type (co-host, co-device, none)
         # and sort the groups by storage in reverse order
         sharding_option_groups = _group_and_sort_non_uniform_sharding_options(
-            proposal, sort_by=self._sort_by
+            proposal, sort_by=self._sort_by, balance_modules=self._balance_modules
         )
 
         for sharding_option_group in sharding_option_groups:
@@ -339,9 +367,15 @@ class MemoryBalancedPartitioner(Partitioner):
     """
     Memory balanced Partitioner.
     """
-    def __init__(self, max_search_count: int = 10, tolerance: float = 0.02) -> None:
+    def __init__(
+        self,
+        max_search_count: int = 10,
+        tolerance: float = 0.02,
+        balance_modules: bool = False,
+    ) -> None:
         self._max_search_count: int = max_search_count
         self._tolerance: float = tolerance
+        self._balance_modules: bool = balance_modules
 
     def partition(
         self,
@@ -354,7 +388,9 @@ def partition(
         of memory.
         """
         _perf_model: PerfModel = NoopPerfModel(storage_constraint)
-        _partitioner = GreedyPerfPartitioner(sort_by=SortBy.PERF)
+        _partitioner = GreedyPerfPartitioner(
+            sort_by=SortBy.PERF, balance_modules=self._balance_modules
+        )
 
         # copying storage_constraint, since we modify it in place
         _topology: Topology = copy.deepcopy(storage_constraint)
diff --git a/torchrec/distributed/planner/tests/test_partitioners.py b/torchrec/distributed/planner/tests/test_partitioners.py
index da91e22a6..3d85847ec 100644
--- a/torchrec/distributed/planner/tests/test_partitioners.py
+++ b/torchrec/distributed/planner/tests/test_partitioners.py
@@ -714,3 +714,133 @@ def test_different_sharding_plan(self) -> None:
                 memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm
 
         self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))
+
+
+class TestBalanceModules(unittest.TestCase):
+    def setUp(self) -> None:
+        compute_device = "cuda"
+        self.topology = Topology(world_size=2, compute_device=compute_device)
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i,
+                embedding_dim=4 * (10 + i),
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(1)
+        ]
+        weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=200 + i,
+                embedding_dim=8 * (10 + i),
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(3)
+        ]
+        self.topology = Topology(
+            world_size=2,
+            compute_device=compute_device,
+            hbm_cap=2000 * 1024**2,
+        )
+        self.model = TestSparseNN(tables=tables, weighted_tables=weighted_tables)
+        self.enumerator = EmbeddingEnumerator(
+            topology=self.topology, batch_size=BATCH_SIZE
+        )
+
+        self.sharding_options = self.enumerator.enumerate(
+            module=self.model, sharders=[TWSharder()]
+        )
+        for sharding_option in self.sharding_options:
+            sharding_option.shards[0].perf = Perf(
+                fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10
+            )
+            sharding_option.shards[0].storage = Storage(
+                hbm=10 * 1024**2, ddr=1000 * 1024**2
+            )
+
+    def test_greedy_partitioner(self) -> None:
+        greedy_partitioner = GreedyPerfPartitioner(balance_modules=False)
+        balance_modules_greedy_partitioner = GreedyPerfPartitioner(balance_modules=True)
+
+        greedy_sharding_plan = greedy_partitioner.partition(
+            proposal=self.sharding_options,
+            storage_constraint=self.topology,
+        )
+        greedy_ranks = {
+            sharding_option.name: [shard.rank for shard in sharding_option.shards]
+            for sharding_option in greedy_sharding_plan
+        }
+
+        reset_shard_rank(self.sharding_options)
+
+        balance_modules_sharding_plan = balance_modules_greedy_partitioner.partition(
+            proposal=self.sharding_options,
+            storage_constraint=self.topology,
+        )
+        balance_modules_ranks = {
+            sharding_option.name: [shard.rank for shard in sharding_option.shards]
+            for sharding_option in balance_modules_sharding_plan
+        }
+
+        greedy_expected_ranks = {
+            "weighted_table_0": [0],
+            "weighted_table_1": [1],
+            "weighted_table_2": [0],
+            "table_0": [1],
+        }
+        balance_modules_expected_ranks = {
+            "weighted_table_0": [1],
+            "weighted_table_1": [0],
+            "weighted_table_2": [1],
+            "table_0": [0],
+        }
+
+        self.assertEqual(greedy_expected_ranks, greedy_ranks)
+        self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks)
+
+    def test_memory_balanced_partitioner(self) -> None:
+        memory_balanced_partitioner = MemoryBalancedPartitioner(
+            tolerance=100, balance_modules=False
+        )
+        balance_modules_memory_balanced_partitioner = MemoryBalancedPartitioner(
+            tolerance=100, balance_modules=True
+        )
+
+        memory_balanced_plan = memory_balanced_partitioner.partition(
+            proposal=self.sharding_options,
+            storage_constraint=self.topology,
+        )
+        memory_balanced_ranks = {
+            sharding_option.name: [shard.rank for shard in sharding_option.shards]
+            for sharding_option in memory_balanced_plan
+        }
+
+        reset_shard_rank(self.sharding_options)
+
+        balance_modules_sharding_plan = (
+            balance_modules_memory_balanced_partitioner.partition(
+                proposal=self.sharding_options,
+                storage_constraint=self.topology,
+            )
+        )
+        balance_modules_ranks = {
+            sharding_option.name: [shard.rank for shard in sharding_option.shards]
+            for sharding_option in balance_modules_sharding_plan
+        }
+
+        memory_balanced_expected_ranks = {
+            "weighted_table_0": [0],
+            "weighted_table_1": [1],
+            "weighted_table_2": [0],
+            "table_0": [1],
+        }
+        balance_modules_expected_ranks = {
+            "weighted_table_0": [1],
+            "weighted_table_1": [0],
+            "weighted_table_2": [1],
+            "table_0": [0],
+        }
+
+        self.assertEqual(memory_balanced_expected_ranks, memory_balanced_ranks)
+        self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks)
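
Note (illustration only, not part of the patch): the sketch below is a minimal standalone mirror of the sorting idea above. ToyGroup, sort_groups and greedy_place are hypothetical stand-ins for ShardingOptionGroup, the sort in _group_and_sort_non_uniform_sharding_options, and a greedy placement pass; the point is that storing the module count negated and prepending it to the sort attributes makes tables from smaller modules land first, so a least-loaded greedy pass then spreads each module across ranks.

# Standalone sketch, not part of the patch. ToyGroup, sort_groups and
# greedy_place are hypothetical stand-ins for ShardingOptionGroup, the sort in
# _group_and_sort_non_uniform_sharding_options, and the greedy placement loop.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGroup:
    name: str
    module_count: int  # negated count, so smaller modules sort first under reverse=True
    storage_sum: int


def sort_groups(groups: List[ToyGroup], balance_modules: bool) -> List[ToyGroup]:
    # Mirrors the patch: optionally prepend module_count to the sort attributes.
    attrs = (["module_count"] if balance_modules else []) + ["storage_sum"]
    return sorted(groups, key=lambda g: [getattr(g, a) for a in attrs], reverse=True)


def greedy_place(groups: List[ToyGroup], world_size: int) -> List[int]:
    # Toy greedy pass: put each group on the currently least-loaded rank.
    used = [0] * world_size
    ranks = []
    for group in groups:
        rank = min(range(world_size), key=lambda r: used[r])
        used[rank] += group.storage_sum
        ranks.append(rank)
    return ranks


if __name__ == "__main__":
    # Three tables from a weighted module plus one table from a smaller module,
    # all with equal storage, matching the shape of the unit tests above.
    groups = [
        ToyGroup("weighted_table_0", module_count=-3, storage_sum=10),
        ToyGroup("weighted_table_1", module_count=-3, storage_sum=10),
        ToyGroup("weighted_table_2", module_count=-3, storage_sum=10),
        ToyGroup("table_0", module_count=-1, storage_sum=10),
    ]
    for balance in (False, True):
        order = sort_groups(groups, balance_modules=balance)
        placement = dict(zip([g.name for g in order], greedy_place(order, world_size=2)))
        print(f"balance_modules={balance}: {placement}")

Under these assumptions the run prints the same split the new unit tests assert: with balance_modules=True, table_0 and weighted_table_1 land on rank 0 while weighted_table_0 and weighted_table_2 land on rank 1.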