diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 975c7e620..15834a6c1 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -108,6 +108,15 @@ def enumerate(
                         named_modules_queue.append((n, m))
                 continue
 
+            # Determine the pooling state for all sharding_options using this
+            # (child_module, child_path). With this optimization, we change enumerate()
+            # from being O(N^2) with respect to the number of tables to O(N). The
+            # previous quadratic behavior arose because in populate_estimates(), invoked below,
+            # each sharding_option needs to determine its pooling state, which it does via
+            # an expensive O(N) walk through the list of embedding tables. With this
+            # change, sharding_option.is_pooled becomes O(1).
+            is_pooled = ShardingOption.module_pooled(child_module, child_path)
+
             for name, param in sharder.shardable_parameters(child_module).items():
                 (
                     input_lengths,
@@ -160,6 +169,7 @@
                                 stochastic_rounding=stochastic_rounding,
                                 bounds_check_mode=bounds_check_mode,
                                 dependency=dependency,
+                                is_pooled=is_pooled,
                             )
                         )
         if not sharding_options:
diff --git a/torchrec/distributed/planner/tests/benchmark.py b/torchrec/distributed/planner/tests/benchmark.py
new file mode 100644
index 000000000..4e78d3f79
--- /dev/null
+++ b/torchrec/distributed/planner/tests/benchmark.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Stress tests for planner to find problematic scaling behavior."""
+
+import time
+import unittest
+
+from typing import List, Tuple
+
+from torch import nn
+
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.planner.constants import BATCH_SIZE
+from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
+from torchrec.distributed.planner.types import Topology
+from torchrec.distributed.test_utils.test_model import TestSparseNN
+from torchrec.distributed.types import ModuleSharder, ShardingType
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+
+
+class TWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [ShardingType.TABLE_WISE.value]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.DENSE.value]
+
+
+class TestEnumeratorBenchmark(unittest.TestCase):
+    @staticmethod
+    def build(
+        world_size: int, num_tables: int
+    ) -> Tuple[EmbeddingEnumerator, nn.Module]:
+        compute_device = "cuda"
+        topology = Topology(
+            world_size=world_size, local_world_size=8, compute_device=compute_device
+        )
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i,
+                embedding_dim=128,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_tables)
+        ]
+        model = TestSparseNN(tables=tables, weighted_tables=[])
+        enumerator = EmbeddingEnumerator(topology=topology, batch_size=BATCH_SIZE)
+        return enumerator, model
+
+    def measure(self, world_size: int, num_tables: int) -> float:
+        enumerator, model = TestEnumeratorBenchmark.build(world_size, num_tables)
+
+        start_time = time.time()
+        sharding_options = enumerator.enumerate(module=model, sharders=[TWSharder()])
+        end_time = time.time()
+
+        self.assertEqual(len(sharding_options), num_tables)
+        return end_time - start_time
+
+    def test_benchmark(self) -> None:
+        tests = [(2048, d) for d in [100, 200, 400, 800, 1600, 3200, 6400]]
+        print("\nEnumerator benchmark:")
+        for world_size, num_tables in tests:
+            t = self.measure(world_size, num_tables)
+            print(
+                f"world_size={world_size:8} num_tables={num_tables:8} enumerate={t:4.2f}s"
+            )
+
+
+# This is structured as a unittest-like file so you can use its built-in command
+# line argument parsing to control which benchmarks to run, e.g. "-k Enumerator"
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 6a043770d..6602b4393 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -262,6 +262,7 @@ def __init__(
         stochastic_rounding: Optional[bool] = None,
         bounds_check_mode: Optional[BoundsCheckMode] = None,
         dependency: Optional[str] = None,
+        is_pooled: Optional[bool] = None,
     ) -> None:
         self.name = name
         self._tensor = tensor
@@ -279,7 +280,7 @@
         self.stochastic_rounding = stochastic_rounding
         self.bounds_check_mode = bounds_check_mode
         self.dependency = dependency
-        self._is_pooled: Optional[bool] = None
+        self._is_pooled = is_pooled
         self.is_weighted: Optional[bool] = None
 
     @property
@@ -323,21 +324,24 @@ def total_perf(self) -> float:
 
     @property
     def is_pooled(self) -> bool:
-        if self._is_pooled is not None:
-            return self._is_pooled
-
-        if isinstance(self.module[1], EmbeddingCollectionInterface):
-            self._is_pooled = False
-            return self.is_pooled
-        for module in self.module[1].modules():
-            if isinstance(module, EmbeddingCollectionInterface):
-                for name, _ in module.named_parameters():
-                    if self.name in name:
-                        self._is_pooled = False
-                        return self._is_pooled
-        self._is_pooled = True
+        if self._is_pooled is None:
+            self._is_pooled = ShardingOption.module_pooled(self.module[1], self.name)
         return self._is_pooled
 
+    @staticmethod
+    def module_pooled(module: nn.Module, sharding_option_name: str) -> bool:
+        """Determine if module pools output (e.g. EmbeddingBag) or uses unpooled/sequential output."""
+        if isinstance(module, EmbeddingCollectionInterface):
+            return False
+
+        for submodule in module.modules():
+            if isinstance(submodule, EmbeddingCollectionInterface):
+                for name, _ in submodule.named_parameters():
+                    if sharding_option_name in name:
+                        return False
+
+        return True
+
     def __hash__(self) -> int:
         return hash(
             (
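
For reference, a minimal sketch (not part of the patch) of how the new ShardingOption.module_pooled() helper behaves. It assumes torchrec's public EmbeddingBagCollection and EmbeddingCollection modules; the table and feature names are hypothetical. Because the result depends only on the module, enumerate() can compute it once per (child_path, child_module) and pass it into every ShardingOption instead of re-deriving it with an O(N) parameter walk per table.

# Illustrative sketch only; assumes torchrec's EmbeddingBagCollection /
# EmbeddingCollection modules and the module_pooled() helper added above.
from torchrec.distributed.planner.types import ShardingOption
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection, EmbeddingCollection

# A pooled module: module_pooled() finds no EmbeddingCollectionInterface
# submodule, so it returns True.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100, embedding_dim=16, name="t0", feature_names=["f0"]
        )
    ]
)
assert ShardingOption.module_pooled(ebc, "t0")

# A sequence module: the isinstance(..., EmbeddingCollectionInterface) check
# matches immediately, so the result is False regardless of the name.
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            num_embeddings=100, embedding_dim=16, name="t1", feature_names=["f1"]
        )
    ]
)
assert not ShardingOption.module_pooled(ec, "t1")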