diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index ef98fac44..f53f34c55 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -5,10 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import itertools import logging from decimal import Decimal -from typing import cast, Dict, List, Optional, Set, Tuple +from typing import cast, Dict, List, Optional, Set, Tuple, Union + +import torch + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.planner.types import ( Enumerator, @@ -17,7 +22,7 @@ ShardingOption, Topology, ) -from torchrec.distributed.planner.utils import prod +from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod logger: logging.Logger = logging.getLogger(__name__) @@ -272,6 +277,316 @@ def feedback( self._proposal_index += 1 +class EmbeddingOffloadScaleupProposer(Proposer): + def __init__(self, use_depth: bool = True) -> None: + self.use_depth: bool = use_depth + self.enumerator: Optional[Enumerator] = None + self.starting_proposal: List[ShardingOption] = [] + self.proposal: Optional[List[ShardingOption]] = None + self.search: Optional[BinarySearchPredicate] = None + self.previous_plan_perf_rating: float = 0.0 + + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: + self.enumerator = enumerator + sharding_options_by_fqn: Dict[str, List[ShardingOption]] = {} + for sharding_option in search_space: + sharding_options_by_fqn.setdefault(sharding_option.fqn, []).append( + sharding_option + ) + for sharding_options in sharding_options_by_fqn.values(): + sharding_options.sort( + key=lambda x: _sharding_option_score(x, self.use_depth) + ) + # currently only use 1st sharding option for proposal only. + # TODO: could traverse through multiple options like GreedyProposer + proposal = [ + sharding_options[0] for sharding_options in sharding_options_by_fqn.values() + ] + # deepcopy so it won't affect other proposers + self.starting_proposal = copy.deepcopy(proposal) + self.proposal = copy.deepcopy(self.starting_proposal) + + def propose(self) -> Optional[List[ShardingOption]]: + return self.proposal + + def feedback( + self, + partitionable: bool, + plan: Optional[List[ShardingOption]] = None, + perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, + ) -> None: + if not self.enumerator or not plan or not storage_constraint: + self.proposal = None + return + + hbm_used_previously = sum( + sharding_option.total_storage.hbm for sharding_option in plan + ) + + if self.search is None: + # Determine how much extra HBM memory is available for scaling our caches + # beyond the baseline-min-working-set plan. We may not be able to find a + # partitionable plan that uses all this budget, or we may find a plan that + # uses only a portion of this budget enables a layout that reduces overall + # cost at the expense of larger prefetch delay. So we perform a binary + # search to sample plans with different budgets to discover a good + # configuration. + hbm_available = EmbeddingOffloadScaleupProposer.get_budget( + plan, storage_constraint + ) + logger.info( + f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB" + ) + # Partitioning proposals is quite expensive when there are a lot of tables, + # so we reduce the number probes the binary search uses to find the max + # cache sizes that fit inside the budget by specifying a tolerance. Once + # we've less than tolerance bytes left of unused budget we stop searching. + # We set tolerance to try to waste less than 3% of budget. For 100TB budget, + # this reduces number of proposals from 47 to 6. + tolerance = round(hbm_available * 0.03) + self.search = BinarySearchPredicate(0, hbm_available, tolerance) + + logger.info( + f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}" + ) + + # Guide the binary search. We assume the partitioned perf model cost is + # monotonic with respect to CLF, so if the feedback from our prior attempt was + # worse than previous one, we reduce the memory for our next proposal, else we + # try using more. This allows us to focus the budget allocation search into the + # productive region where plans are still getting better. + warmer = partitionable and ( + self.previous_plan_perf_rating == 0.0 + or ( + perf_rating is not None and perf_rating < self.previous_plan_perf_rating + ) + ) + self.previous_plan_perf_rating = perf_rating or 0.0 + + assert self.search is not None # keep pyre happy + budget = self.search.next(warmer) + self.proposal = EmbeddingOffloadScaleupProposer.next_plan( + self.starting_proposal, budget, self.enumerator + ) + + @staticmethod + def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int: + """returns additional HBM budget available for GPU caches.""" + available_hbm = sum(device.storage.hbm for device in storage_constraint.devices) + used_hbm = sum( + sharding_option.total_storage.hbm for sharding_option in proposal + ) + return available_hbm - used_hbm + + # Given an available budget of additional memory, and a provisional sharding plan, + # attempt to use the budget wisely to scale up caches that would most benefit from it. + @staticmethod + def next_plan( + starting_proposal: List[ShardingOption], + budget: Optional[int], + enumerator: Optional[Enumerator], + ) -> Optional[List[ShardingOption]]: + if budget is None or enumerator is None: + return None + + def none_to_zero(x: Optional[float]) -> float: + return x if x is not None else 0.0 + + proposal = copy.deepcopy(starting_proposal) + # This is the subset of tables that we can scale + cache_tables = [ + sharding_option + for sharding_option in proposal + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + and none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option) + ) + > 0 + ] + # Nothing to scale + if len(cache_tables) == 0: + return None + + size_model = EmbeddingOffloadScaleupProposer.build_affine_storage_model( + cache_tables, enumerator + ) + clfs = torch.tensor( + [ + EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option) + for sharding_option in cache_tables + ] + ) + # cooked_cacheability is cacheability scaled by the expected number of cache + # lookups. + + cooked_cacheability = torch.tensor( + [ + none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups( + sharding_option + ) + ) + for sharding_option in cache_tables + ] + ) + new_clfs = EmbeddingOffloadScaleupProposer.allocate_budget( + model=size_model, + clfs=clfs, + budget=budget, + allocation_priority=cooked_cacheability, + ) + # apply new_clfs, promoting tables that made it to 1.0 + for sharding_option, clf in zip(cache_tables, new_clfs): + clf = clf.item() # tensor scalar -> scalar + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + if clf > 0.9999: # tolerate float roundoff + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + # recalculate cost estimates of modified tables + enumerator.populate_estimates(cache_tables) + return proposal + + @staticmethod + def get_cacheability(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.cacheability + + @staticmethod + def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.expected_lookups + + @staticmethod + def get_load_factor(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.load_factor + + # The relationship between clf and shard memory usage is non-linear due to non-clf + # overheads like optimization stats and input/output storage. We model it as an + # affine relationship: bytes = clf * A + B where B is fixed overhead independent of + # CLF (e.g. input / output IO sizes and A is per cache-row overhead. + @staticmethod + def build_affine_storage_model( + uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator + ) -> torch.Tensor: + plan: List[ShardingOption] = copy.deepcopy(uvm_caching_sharding_options) + + def compute_hbm_sizes(clf: float) -> torch.Tensor: + for sharding_option in plan: + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + enumerator.populate_estimates(plan) + return torch.tensor( + [sharding_option.total_storage.hbm for sharding_option in plan] + ) + + low_clf, high_clf = 0.1, 0.9 + low_hbms = compute_hbm_sizes(low_clf) + high_hbms = compute_hbm_sizes(high_clf) + + A = (high_hbms - low_hbms) / (high_clf - low_clf) + B = low_hbms - A * low_clf + return torch.stack((A, B), dim=1) # Nx2 (a,b) + + @staticmethod + def clf_to_bytes( + model: torch.Tensor, clfs: Union[float, torch.Tensor] + ) -> torch.Tensor: + # evaluate affine model AX + B + return (model[:, 0] * clfs + model[:, 1]).to(torch.int64) + + # Given a model of an affine system, an existing configuration (clfs), available + # budget, and an allocation policy, return new configuration that best uses the + # available budget. We only add additional budget, we assume the existing + # configuration is specifying a floor or minimum size. + @staticmethod + def allocate_budget( + model: torch.Tensor, + clfs: torch.Tensor, + budget: int, + allocation_priority: torch.Tensor, + ) -> torch.Tensor: + # min size is size of table at 0 CLF + min_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 0) + max_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1) + table_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, clfs) + cache_size_bytes = table_size_bytes - min_size_bytes + max_cache_size_bytes = max_size_bytes - min_size_bytes + + # We have budget bytes to share across the tables in. We want to increase the + # cache_size_bytes of each table in proportion to their allocation priority + # fraction. If we raise the cache_size_bytes to beyond max_cache_size_bytes, + # this is equivalent to reaching CLF=1.0, so we clip the memory to 1.0, and + # reassign the released budget in a subsequent pass. + num_pass = 0 + while budget > 1 and num_pass < 128: + num_pass += 1 + # mask is False for tables at >= max_size, and True otherwise. This allows + # us to remove tables that have already reached full size in one round from + # being dealt more budget in subsequent rounds. + mask = (min_size_bytes + cache_size_bytes) < max_size_bytes + if mask.sum() == 0: + break + + logging.debug( + f"[allocate_budget] pass={num_pass}, budget={budget}, #cache_tables={mask.sum()}" + ) + + # switch to 64bit float to avoid rounding errors, as table cache sizes can + # easily be > 2^24. + masked_priority = (mask * allocation_priority).to(torch.float64) + increase_ratio = masked_priority / torch.sum(masked_priority) + proposed_increase_bytes = budget * increase_ratio + new_cache_size_bytes = torch.minimum( + cache_size_bytes + proposed_increase_bytes, max_cache_size_bytes + ) + actual_increase_bytes = new_cache_size_bytes - cache_size_bytes + + budget -= torch.sum(actual_increase_bytes) + cache_size_bytes = new_cache_size_bytes + # TODO: consider trade off of using remaining budget to push >0.95 tables + # to HBM vs spending that budget on improving hit rate on other tables in + # next pass. + + # cache_size_bytes are the new cache sizes we want to use. We convert them back + # to clfs by dividing by max_cache_size_bytes, which has isolated the clf + # portion of the table size from the fixed overheads. + # convert 64bit values back to original clf precision + return (cache_size_bytes / max_cache_size_bytes).to(clfs.dtype) + + def _sharding_option_score( sharding_option: ShardingOption, use_depth: bool = True ) -> float: diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index 6fd3c8998..89574032c 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -10,10 +10,12 @@ from unittest.mock import MagicMock import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.constants import BATCH_SIZE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.proposers import ( + EmbeddingOffloadScaleupProposer, GreedyProposer, GridSearchProposer, proposers_to_proposals_list, @@ -21,12 +23,18 @@ ) from torchrec.distributed.planner.types import ( Enumerator, + ParameterConstraints, Proposer, ShardingOption, Topology, ) from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import ( + CacheParams, + CacheStatistics, + ModuleSharder, + ShardingType, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -51,6 +59,23 @@ def propose(self) -> Optional[List[ShardingOption]]: pass +class MockCacheStatistics(CacheStatistics): + def __init__(self, expected_lookups: int, cacheability: float) -> None: + self._expected_lookups = expected_lookups + self._cacheability = cacheability + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + return clf + + @property + def cacheability(self) -> float: + return self._cacheability + + class TestProposers(unittest.TestCase): def setUp(self) -> None: topology = Topology(world_size=2, compute_device="cuda") @@ -316,6 +341,293 @@ def test_grid_search_three_table(self) -> None: self.assertEqual(num_pruned_options ** len(tables), num_proposals) + def test_allocate_budget(self) -> None: + model = torch.tensor([[1.0, 0.0], [2.0, 3.0], [4.0, 5.0]]) + got = EmbeddingOffloadScaleupProposer.clf_to_bytes( + model, torch.tensor([0, 0.5, 1]) + ) + torch.testing.assert_close(got, torch.tensor([0, 4, 9])) + + # Scenario 1, enough budget to scale everything to 1.0 + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 100_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + clfs=torch.tensor(mins), + budget=budget, + allocation_priority=torch.tensor([2, 2, 2]), + ) + torch.testing.assert_close(got, torch.tensor([1.0, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ).item() + self.assertLess(increase, budget) + + # Scenario 2, limited budget, uniform scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 2, 2]) + ) + torch.testing.assert_close(got, torch.tensor([0.26667, 0.26667, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 3, limited budget, skewed scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 4, 2]) + ) + # increase is twice as much for table 2 (started at 0.1) + torch.testing.assert_close( + got, torch.tensor([0.1 + 0.11111, 0.1 + 2 * 0.11111, 1.0]) + ) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 4, multi-pass scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.3, 0.5]) + budget = 50_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([1, 2, 100]), + ) + torch.testing.assert_close(got, torch.tensor([0.56667, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + def test_scaleup(self) -> None: + + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + + # Place first two tables into cache, 3rd table leave on hbm. table_1 has a + # larger cacheability score so budget should be skewed to scaling table_1 more + # than table_0. + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5), + ), + ), + } + + MB = 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB + ) + + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + output = [] + proposal = proposer.propose() + while proposal is not None: + output.append( + [ + ( + candidate.name, + candidate.compute_kernel, + candidate.cache_params.load_factor + if candidate.cache_params + else None, + ) + for candidate in proposal + ] + ) + proposer.feedback( + partitionable=True, + plan=proposal, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + # Expected output (name, kernel clf). + # First attempt uses the mins supplied, then as we apply increasing budget + # clfs increase, with the later attempts enough to promote table_3 into hbm. + expected_output = [ + [ + ("table_0", "fused_uvm_caching", 0.1), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.3025801181793213), + ("table_1", "fused_uvm_caching", 0.6064502596855164), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.403870165348053), + ("table_1", "fused_uvm_caching", 0.859675407409668), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.4545151889324188), + ("table_1", "fused_uvm_caching", 0.9862880110740662), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.5294319987297058), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.573746383190155), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.5959035754203796), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + ] + + self.assertEqual(output, expected_output) + + def test_scaleup_ample_budget_and_deprecated_feature(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + + # Place first two tables into cache, 3rd table leave on hbm. table_1 has an + # expected lookup of 0 (deprecated feature). + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=0, cacheability=0), + ), + ), + } + + MB = 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB + ) + + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + output = [] + proposal = proposer.propose() + while proposal is not None: + output.append( + [ + ( + candidate.name, + candidate.compute_kernel, + candidate.cache_params.load_factor + if candidate.cache_params + else None, + ) + for candidate in proposal + ] + ) + proposer.feedback( + partitionable=True, + plan=proposal, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + # Expected output (name, kernel clf). + # First attempt uses the mins supplied, then as we apply increasing budget + # clfs increase, table 0 gets promoted, table 1 left as original minimum. + expected_output = [ + [ + ("table_0", "fused_uvm_caching", 0.1), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.8090304136276245), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused", None), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + ] + self.assertEqual(output[0:3], expected_output) + def test_proposers_to_proposals_list(self) -> None: def make_mock_proposal(name: str) -> List[ShardingOption]: return [ diff --git a/torchrec/distributed/planner/tests/test_utils.py b/torchrec/distributed/planner/tests/test_utils.py index c28b5e4c3..c1b9e8583 100644 --- a/torchrec/distributed/planner/tests/test_utils.py +++ b/torchrec/distributed/planner/tests/test_utils.py @@ -6,11 +6,15 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import List +from typing import Callable, List from unittest.mock import MagicMock from torchrec.distributed.planner.types import Perf, Shard, ShardingOption, Storage -from torchrec.distributed.planner.utils import _find_imbalance_tables, reset_shard_rank +from torchrec.distributed.planner.utils import ( + _find_imbalance_tables, + BinarySearchPredicate, + reset_shard_rank, +) from torchrec.distributed.types import ShardingType @@ -74,3 +78,36 @@ def test_find_hbm_imbalance_tables(self) -> None: ) ] self.assertTrue(expected_max_hbm_table_names, max_hbm_table_names) + + +class TestBinarySearchPredicate(unittest.TestCase): + def test_binary_search_predicate(self) -> None: + def F(x: int) -> bool: + return x < 90 + + def probes( + search: BinarySearchPredicate, f: Callable[[int], bool] + ) -> List[int]: + r = [] + probe = search.next(True) + while probe is not None: + r.append(probe) + probe = search.next(f(probe)) + return r + + got = probes(BinarySearchPredicate(0, 100, 0), F) + self.assertEqual(got, [50, 75, 88, 94, 91, 89, 90]) + got = probes(BinarySearchPredicate(0, 100, 3), F) + self.assertEqual(got, [50, 75, 88, 94, 91]) + got = probes(BinarySearchPredicate(0, 100, 20), F) + self.assertEqual(got, [50, 75, 88]) + + got = probes(BinarySearchPredicate(91, 100, 0), F) + self.assertEqual(got, [95, 92, 91]) + got = probes(BinarySearchPredicate(1, 10, 0), F) + self.assertEqual(got, [5, 8, 9, 10]) + + got = probes(BinarySearchPredicate(1, 1, 0), F) + self.assertEqual(got, [1]) + got = probes(BinarySearchPredicate(1, 0, 0), F) + self.assertEqual(got, []) diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py index 715869311..4594396f8 100644 --- a/torchrec/distributed/planner/utils.py +++ b/torchrec/distributed/planner/utils.py @@ -123,3 +123,43 @@ def _find_imbalance_tables( raise ValueError(f"Unknown target imbalance {target_imbalance}") return tables_in_max_value_ranks + + +class BinarySearchPredicate: + """Generates values of X between A & B to invoke on an external predicate F(X) to + discover the largest X for which F(X) is true. Uses binary search to minimize the + number of invocations of F. Assumes F is a step function, i.e. if F(X) is false, + there is no point trying F(X+1).""" + + def __init__(self, A: int, B: int, tolerance: int) -> None: + """A = lower boundary (inclusive) + B = upper boundary (inclusive) + tolerance = stop search early if remaining search range is less than tolerance""" + self.left = A + self.right = B + self.tolerance = tolerance + self.first = True + + def next(self, prior_result: bool) -> Optional[int]: + """next() returns the next value to probe, given the result of the prior probe. + The first time next() is invoked the prior_result is ignored. Returns None if + entire range explored or threshold reached.""" + if self.right - self.left < self.tolerance: + return None + + mid = self._mid() + if self.first: + self.first = False + return mid + + if prior_result: + self.left = mid + 1 + else: + self.right = mid - 1 + if self.right - self.left < self.tolerance: + return None + + return self._mid() + + def _mid(self) -> int: + return self.left + ((self.right - self.left) // 2)