Skip to content

Commit

Permalink
Add option to balance modules to partitioners (pytorch#1509)
Browse files Browse the repository at this point in the history
Summary:

This diff will make sure module count on GPUs are balanced. The way we do it is to sort the tables from a smaller module first.

This should work fine with dependency (aka towers), since dependency is either child_path or starts with child_path, which is the path.

Differential Revision: D51180772
  • Loading branch information
hlhtsang authored and facebook-github-bot committed Nov 14, 2023
1 parent abdd5fb commit 35d8eb3
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 8 deletions.
52 changes: 44 additions & 8 deletions torchrec/distributed/planner/partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
from dataclasses import dataclass
from enum import Enum
from typing import cast, List
from typing import cast, Dict, List

from torchrec.distributed.planner.perf_models import NoopPerfModel

Expand Down Expand Up @@ -58,6 +58,7 @@ class ShardingOptionGroup:
sharding_options: List[ShardingOption]
storage_sum: Storage
perf_sum: float
module_count: int


class SortBy(Enum):
Expand All @@ -66,8 +67,20 @@ class SortBy(Enum):


def _group_and_sort_non_uniform_sharding_options(
sharding_options: List[ShardingOption], sort_by: SortBy = SortBy.STORAGE
sharding_options: List[ShardingOption],
sort_by: SortBy = SortBy.STORAGE,
balance_modules: bool = False,
) -> List[ShardingOptionGroup]:

# count modules by name
module_count: Dict[str, int] = {}
for sharding_option in sharding_options:
path = sharding_option.path
if path not in module_count:
module_count[path] = 0
module_count[path] += 1
logger.info(f"module_count is {module_count}")

sharding_option_groups_by_dependency = {}
for sharding_option in sharding_options:
if sharding_option.partition_by == PartitionByType.UNIFORM.value:
Expand All @@ -79,6 +92,8 @@ def _group_and_sort_non_uniform_sharding_options(
[sharding_option],
sharding_option.total_storage,
sharding_option.total_perf,
# negative value to indicate that smaller modules should be sorted first
module_count=-module_count[sharding_option.path],
)
else:
sharding_option_groups_by_dependency[group_key].sharding_options.append(
Expand All @@ -90,15 +105,25 @@ def _group_and_sort_non_uniform_sharding_options(
sharding_option_groups_by_dependency[
group_key
].perf_sum += sharding_option.total_perf

sharding_option_groups = list(sharding_option_groups_by_dependency.values())

sort_by_attributes: List[str] = []
if balance_modules:
sort_by_attributes.append("module_count")

if sort_by == SortBy.STORAGE:
sharding_option_groups.sort(key=lambda group: group.storage_sum, reverse=True)
sort_by_attributes.append("storage_sum")
elif sort_by == SortBy.PERF:
sharding_option_groups.sort(key=lambda group: group.perf_sum, reverse=True)
sort_by_attributes.append("perf_sum")
else:
raise RuntimeError(f"Unexpected sort_by: {sort_by}")

sharding_option_groups.sort(
key=lambda group: [getattr(group, attr) for attr in sort_by_attributes],
reverse=True,
)

return sharding_option_groups


Expand All @@ -107,8 +132,11 @@ class GreedyPerfPartitioner(Partitioner):
Greedy Partitioner
"""

def __init__(self, sort_by: SortBy = SortBy.STORAGE) -> None:
def __init__(
self, sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False
) -> None:
self._sort_by = sort_by
self._balance_modules = balance_modules

def partition(
self,
Expand Down Expand Up @@ -186,7 +214,7 @@ def partition(
# group the rest sharding options by colocation type (co-host, co-device, none)
# and sort the groups by storage in reverse order
sharding_option_groups = _group_and_sort_non_uniform_sharding_options(
proposal, sort_by=self._sort_by
proposal, sort_by=self._sort_by, balance_modules=self._balance_modules
)

for sharding_option_group in sharding_option_groups:
Expand Down Expand Up @@ -339,9 +367,15 @@ class MemoryBalancedPartitioner(Partitioner):
Memory balanced Partitioner.
"""

def __init__(self, max_search_count: int = 10, tolerance: float = 0.02) -> None:
def __init__(
self,
max_search_count: int = 10,
tolerance: float = 0.02,
balance_modules: bool = False,
) -> None:
self._max_search_count: int = max_search_count
self._tolerance: float = tolerance
self._balance_modules: bool = balance_modules

def partition(
self,
Expand All @@ -354,7 +388,9 @@ def partition(
of memory.
"""
_perf_model: PerfModel = NoopPerfModel(storage_constraint)
_partitioner = GreedyPerfPartitioner(sort_by=SortBy.PERF)
_partitioner = GreedyPerfPartitioner(
sort_by=SortBy.PERF, balance_modules=self._balance_modules
)
# copying storage_constraint, since we modify it in place
_topology: Topology = copy.deepcopy(storage_constraint)

Expand Down
133 changes: 133 additions & 0 deletions torchrec/distributed/planner/tests/test_partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,136 @@ def test_different_sharding_plan(self) -> None:
memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm

self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))


class TestBalanceModules(unittest.TestCase):
def setUp(self) -> None:
compute_device = "cuda"
self.topology = Topology(world_size=2, compute_device=compute_device)
tables = [
EmbeddingBagConfig(
num_embeddings=100 + i,
embedding_dim=4 * (10 + i),
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(1)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=200 + i,
embedding_dim=8 * (10 + i),
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(3)
]
self.topology = Topology(
world_size=2,
compute_device=compute_device,
hbm_cap=2000 * 1024**2,
)
self.model = TestSparseNN(tables=tables, weighted_tables=weighted_tables)
self.enumerator = EmbeddingEnumerator(
topology=self.topology, batch_size=BATCH_SIZE
)

self.sharding_options = self.enumerator.enumerate(
module=self.model, sharders=[TWSharder()]
)
for sharding_option in self.sharding_options:
sharding_option.shards[0].perf = Perf(
fwd_compute=40, fwd_comms=30, bwd_compute=20, bwd_comms=10
)
sharding_option.shards[0].storage = Storage(
hbm=10 * 1024**2, ddr=1000 * 1024**2
)

def test_greedy_partitioner(self) -> None:
greedy_partitioner = GreedyPerfPartitioner(balance_modules=False)
balance_modules_greedy_partitioner = GreedyPerfPartitioner(balance_modules=True)

greedy_sharding_plan = greedy_partitioner.partition(
proposal=self.sharding_options,
storage_constraint=self.topology,
)
greedy_ranks = {
sharding_option.name: [shard.rank for shard in sharding_option.shards]
for sharding_option in greedy_sharding_plan
}

reset_shard_rank(self.sharding_options)

balance_modules_sharding_plan = balance_modules_greedy_partitioner.partition(
proposal=self.sharding_options,
storage_constraint=self.topology,
)
balance_modules_ranks = {
sharding_option.name: [shard.rank for shard in sharding_option.shards]
for sharding_option in balance_modules_sharding_plan
}

greedy_expected_ranks = {
"weighted_table_0": [0],
"weighted_table_1": [1],
"weighted_table_2": [0],
"table_0": [1],
}
balance_modules_expected_ranks = {
"weighted_table_0": [1],
"weighted_table_1": [0],
"weighted_table_2": [1],
"table_0": [0],
}

self.assertEqual(greedy_expected_ranks, greedy_ranks)
self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks)

def test_memory_balanced_partitioner(self) -> None:
memory_balanced_partitioner = MemoryBalancedPartitioner(
tolerance=100, balance_modules=False
)
balance_modules_memory_balanced_partitioner = MemoryBalancedPartitioner(
tolerance=100, balance_modules=True
)

memory_balanced_plan = memory_balanced_partitioner.partition(
proposal=self.sharding_options,
storage_constraint=self.topology,
)
memory_balanced_ranks = {
sharding_option.name: [shard.rank for shard in sharding_option.shards]
for sharding_option in memory_balanced_plan
}

reset_shard_rank(self.sharding_options)

balance_modules_sharding_plan = (
balance_modules_memory_balanced_partitioner.partition(
proposal=self.sharding_options,
storage_constraint=self.topology,
)
)
balance_modules_ranks = {
sharding_option.name: [shard.rank for shard in sharding_option.shards]
for sharding_option in balance_modules_sharding_plan
}

memory_balanced_expected_ranks = {
"weighted_table_0": [0],
"weighted_table_1": [1],
"weighted_table_2": [0],
"table_0": [1],
}
balance_modules_expected_ranks = {
"weighted_table_0": [1],
"weighted_table_1": [0],
"weighted_table_2": [1],
"table_0": [0],
}

print(f"henry greedy_ranks {memory_balanced_ranks}")
print(f"henry balance_modules_expected_ranks {balance_modules_ranks}")

self.assertEqual(memory_balanced_expected_ranks, memory_balanced_ranks)
self.assertEqual(balance_modules_expected_ranks, balance_modules_ranks)

0 comments on commit 35d8eb3

Please sign in to comment.