
Fix O(N^2) scaling in enumerator for large models
Summary:
EmbeddingEnumerator.enumerate() is O(N^2) with respect to the number
of tables. For large models (5000 tables), this is very noticeable.

This diff fixes the N^2 problem by constructing our sharding_options
with the correct is_pooled setting, rather than having each sharding_option
perform an expensive O(N) search to establish this fact.

Added a benchmark to show the behavior. The change gives a >100x saving for larger
models, removing 30 sec from the time-to-first-batch metric.

Reviewed By: henrylhtsang

Differential Revision: D51490378
damianr99 authored and facebook-github-bot committed Dec 6, 2023
1 parent a0376a1 commit 994136a
Showing 3 changed files with 110 additions and 14 deletions.
10 changes: 10 additions & 0 deletions torchrec/distributed/planner/enumerators.py
@@ -108,6 +108,15 @@ def enumerate(
named_modules_queue.append((n, m))
continue

# Determine the pooling state for all sharding_options using this
# (child_module, child_path). With this optimization, we change enumerate()
# from being O(N^2) with respect to the number of tables to O(N). The
# previous quadratic behavior is because in populate_estimates() invoked below, each
# sharding_option needs to determine its pooling state, which it does via
# an expensive O(N) walk through the list of embedding tables. With this
# change, sharding_option.is_pooled becomes O(1).
is_pooled = ShardingOption.module_pooled(child_module, child_path)

for name, param in sharder.shardable_parameters(child_module).items():
(
input_lengths,
@@ -160,6 +169,7 @@ def enumerate(
stochastic_rounding=stochastic_rounding,
bounds_check_mode=bounds_check_mode,
dependency=dependency,
is_pooled=is_pooled,
)
)
if not sharding_options:
82 changes: 82 additions & 0 deletions torchrec/distributed/planner/tests/benchmark.py
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Stress tests for planner to find problematic scaling behavior."""

import time
import unittest

from typing import List, Tuple

from torch import nn

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.constants import BATCH_SIZE
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig


class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
    def sharding_types(self, compute_device_type: str) -> List[str]:
        return [ShardingType.TABLE_WISE.value]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        return [EmbeddingComputeKernel.DENSE.value]


class TestEnumeratorBenchmark(unittest.TestCase):
    @staticmethod
    def build(
        world_size: int, num_tables: int
    ) -> Tuple[EmbeddingEnumerator, nn.Module]:
        compute_device = "cuda"
        topology = Topology(
            world_size=world_size, local_world_size=8, compute_device=compute_device
        )
        tables = [
            EmbeddingBagConfig(
                num_embeddings=100 + i,
                embedding_dim=128,
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(num_tables)
        ]
        model = TestSparseNN(tables=tables, weighted_tables=[])
        enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
        return enumerator, model

    def measure(self, world_size: int, num_tables: int) -> float:
        enumerator, model = TestEnumeratorBenchmark.build(world_size, num_tables)

        start_time = time.time()
        sharding_options = enumerator.enumerate(module=model, sharders=[TWSharder()])
        end_time = time.time()

        self.assertEqual(len(sharding_options), num_tables)
        return end_time - start_time

    def test_benchmark(self) -> None:
        tests = [(2048, d) for d in [100, 200, 400, 800, 1600, 3200, 6400]]
        print("\nEnumerator benchmark:")
        for world_size, num_tables in tests:
            t = self.measure(world_size, num_tables)
            print(
                f"world_size={world_size:8} num_tables={num_tables:8} enumerate={t:4.2f}s"
            )


# This is structured as a unittest-like file so you can use its built-in command
# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator"
if __name__ == "__main__":
    unittest.main()
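
As the trailing comment notes, the benchmark can be run directly and filtered with
unittest's built-in -k option, e.g.
"python torchrec/distributed/planner/tests/benchmark.py -k Enumerator". A rough
sketch of driving it programmatically instead, assuming the module is importable
under the path shown above:

import unittest

# Load only the enumerator benchmark case and run it with a standard runner;
# the dotted name mirrors the file path added in this diff.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "torchrec.distributed.planner.tests.benchmark.TestEnumeratorBenchmark"
)
unittest.TextTestRunner(verbosity=2).run(suite)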
32 changes: 18 additions & 14 deletions torchrec/distributed/planner/types.py
@@ -262,6 +262,7 @@ def __init__(
stochastic_rounding: Optional[bool] = None,
bounds_check_mode: Optional[BoundsCheckMode] = None,
dependency: Optional[str] = None,
is_pooled: Optional[bool] = None,
) -> None:
self.name = name
self._tensor = tensor
Expand All @@ -279,7 +280,7 @@ def __init__(
self.stochastic_rounding = stochastic_rounding
self.bounds_check_mode = bounds_check_mode
self.dependency = dependency
self._is_pooled: Optional[bool] = None
self._is_pooled = is_pooled
self.is_weighted: Optional[bool] = None

@property
@@ -323,21 +324,24 @@ def total_perf(self) -> float:

@property
def is_pooled(self) -> bool:
if self._is_pooled is not None:
return self._is_pooled

if isinstance(self.module[1], EmbeddingCollectionInterface):
self._is_pooled = False
return self.is_pooled
for module in self.module[1].modules():
if isinstance(module, EmbeddingCollectionInterface):
for name, _ in module.named_parameters():
if self.name in name:
self._is_pooled = False
return self._is_pooled
self._is_pooled = True
if self._is_pooled is None:
self._is_pooled = ShardingOption.module_pooled(self.module[1], self.name)
return self._is_pooled

@staticmethod
def module_pooled(module: nn.Module, sharding_option_name: str) -> bool:
"""Determine if module pools output (e.g. EmbeddingBag) or uses unpooled/sequential output."""
if isinstance(module, EmbeddingCollectionInterface):
return False

for submodule in module.modules():
if isinstance(submodule, EmbeddingCollectionInterface):
for name, _ in submodule.named_parameters():
if sharding_option_name in name:
return False

return True

def __hash__(self) -> int:
return hash(
(

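
A closing note on the types.py change: the new constructor argument plus the lazy
fallback in the is_pooled property amount to a cache-or-compute pattern: accept a
precomputed value from the enumerator, otherwise run the O(N) module_pooled walk
once and cache the result. A toy sketch of that pattern (PooledFlag below is
illustrative only, not part of TorchRec):

from typing import Optional

class PooledFlag:
    """Illustrative only: mirrors how ShardingOption caches its is_pooled state."""

    def __init__(self, is_pooled: Optional[bool] = None) -> None:
        # The enumerator can pass the flag in up front, making every read O(1).
        self._is_pooled = is_pooled

    @property
    def is_pooled(self) -> bool:
        if self._is_pooled is None:
            # Fallback: run the expensive check once and cache the answer,
            # analogous to calling ShardingOption.module_pooled here.
            self._is_pooled = self._expensive_module_walk()
        return self._is_pooled

    def _expensive_module_walk(self) -> bool:
        # Stand-in for the O(N) walk over submodules and their parameters.
        return True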