From 18df7d63f3ef36d9be7ff7fab45fe4b8b4028a9a Mon Sep 17 00:00:00 2001 From: Damian Reeves Date: Thu, 14 Dec 2023 14:38:33 -0800 Subject: [PATCH] Implement EmbeddingOffloadScaleupProposer (#1558) Summary: Implements a new type of Proposer that attempts to scale up fused_uvm_caches individually according to an allocation policy based on the expected statistical distribution of the cache workload and a budget of HBM memory that is available for caching. Scaling fused_uvm_caches identically (e.g. using a global default load factor) is suboptimal as we find significant differences between the cache load factors needed for different embedding tables to achieve a reasonable miss rate. This diff just implements the Proposer, but does not (yet) use it by default. To enable the new Proposer, the trainer should explicitly specify this Proposer when initializing the EmbeddingShardingPlanner. The cost model for fused_uvm_caching does not yet fully account for storage and perf overheads of the cache. So this proposer should not be used in conjunction with other proposers in the planner. In a later diff we will improve the cost model so proposals generated by caching-aware proposers and existing proposers are comparable, removing this restriction. Reviewed By: henrylhtsang Differential Revision: D51451167 --- torchrec/distributed/planner/proposers.py | 319 +++++++++++++++++- .../planner/tests/test_proposers.py | 314 ++++++++++++++++- .../distributed/planner/tests/test_utils.py | 41 ++- torchrec/distributed/planner/utils.py | 40 +++ 4 files changed, 709 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index ef98fac44..f53f34c55 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -5,10 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import itertools import logging from decimal import Decimal -from typing import cast, Dict, List, Optional, Set, Tuple +from typing import cast, Dict, List, Optional, Set, Tuple, Union + +import torch + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.planner.types import ( Enumerator, @@ -17,7 +22,7 @@ ShardingOption, Topology, ) -from torchrec.distributed.planner.utils import prod +from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod logger: logging.Logger = logging.getLogger(__name__) @@ -272,6 +277,316 @@ def feedback( self._proposal_index += 1 +class EmbeddingOffloadScaleupProposer(Proposer): + def __init__(self, use_depth: bool = True) -> None: + self.use_depth: bool = use_depth + self.enumerator: Optional[Enumerator] = None + self.starting_proposal: List[ShardingOption] = [] + self.proposal: Optional[List[ShardingOption]] = None + self.search: Optional[BinarySearchPredicate] = None + self.previous_plan_perf_rating: float = 0.0 + + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: + self.enumerator = enumerator + sharding_options_by_fqn: Dict[str, List[ShardingOption]] = {} + for sharding_option in search_space: + sharding_options_by_fqn.setdefault(sharding_option.fqn, []).append( + sharding_option + ) + for sharding_options in sharding_options_by_fqn.values(): + sharding_options.sort( + key=lambda x: _sharding_option_score(x, self.use_depth) + ) + # currently only use 1st sharding option for proposal only. + # TODO: could traverse through multiple options like GreedyProposer + proposal = [ + sharding_options[0] for sharding_options in sharding_options_by_fqn.values() + ] + # deepcopy so it won't affect other proposers + self.starting_proposal = copy.deepcopy(proposal) + self.proposal = copy.deepcopy(self.starting_proposal) + + def propose(self) -> Optional[List[ShardingOption]]: + return self.proposal + + def feedback( + self, + partitionable: bool, + plan: Optional[List[ShardingOption]] = None, + perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, + ) -> None: + if not self.enumerator or not plan or not storage_constraint: + self.proposal = None + return + + hbm_used_previously = sum( + sharding_option.total_storage.hbm for sharding_option in plan + ) + + if self.search is None: + # Determine how much extra HBM memory is available for scaling our caches + # beyond the baseline-min-working-set plan. We may not be able to find a + # partitionable plan that uses all this budget, or we may find a plan that + # uses only a portion of this budget enables a layout that reduces overall + # cost at the expense of larger prefetch delay. So we perform a binary + # search to sample plans with different budgets to discover a good + # configuration. + hbm_available = EmbeddingOffloadScaleupProposer.get_budget( + plan, storage_constraint + ) + logger.info( + f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB" + ) + # Partitioning proposals is quite expensive when there are a lot of tables, + # so we reduce the number probes the binary search uses to find the max + # cache sizes that fit inside the budget by specifying a tolerance. Once + # we've less than tolerance bytes left of unused budget we stop searching. + # We set tolerance to try to waste less than 3% of budget. For 100TB budget, + # this reduces number of proposals from 47 to 6. + tolerance = round(hbm_available * 0.03) + self.search = BinarySearchPredicate(0, hbm_available, tolerance) + + logger.info( + f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}" + ) + + # Guide the binary search. We assume the partitioned perf model cost is + # monotonic with respect to CLF, so if the feedback from our prior attempt was + # worse than previous one, we reduce the memory for our next proposal, else we + # try using more. This allows us to focus the budget allocation search into the + # productive region where plans are still getting better. + warmer = partitionable and ( + self.previous_plan_perf_rating == 0.0 + or ( + perf_rating is not None and perf_rating < self.previous_plan_perf_rating + ) + ) + self.previous_plan_perf_rating = perf_rating or 0.0 + + assert self.search is not None # keep pyre happy + budget = self.search.next(warmer) + self.proposal = EmbeddingOffloadScaleupProposer.next_plan( + self.starting_proposal, budget, self.enumerator + ) + + @staticmethod + def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int: + """returns additional HBM budget available for GPU caches.""" + available_hbm = sum(device.storage.hbm for device in storage_constraint.devices) + used_hbm = sum( + sharding_option.total_storage.hbm for sharding_option in proposal + ) + return available_hbm - used_hbm + + # Given an available budget of additional memory, and a provisional sharding plan, + # attempt to use the budget wisely to scale up caches that would most benefit from it. + @staticmethod + def next_plan( + starting_proposal: List[ShardingOption], + budget: Optional[int], + enumerator: Optional[Enumerator], + ) -> Optional[List[ShardingOption]]: + if budget is None or enumerator is None: + return None + + def none_to_zero(x: Optional[float]) -> float: + return x if x is not None else 0.0 + + proposal = copy.deepcopy(starting_proposal) + # This is the subset of tables that we can scale + cache_tables = [ + sharding_option + for sharding_option in proposal + if sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value + and none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option) + ) + > 0 + ] + # Nothing to scale + if len(cache_tables) == 0: + return None + + size_model = EmbeddingOffloadScaleupProposer.build_affine_storage_model( + cache_tables, enumerator + ) + clfs = torch.tensor( + [ + EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option) + for sharding_option in cache_tables + ] + ) + # cooked_cacheability is cacheability scaled by the expected number of cache + # lookups. + + cooked_cacheability = torch.tensor( + [ + none_to_zero( + EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option) + ) + * none_to_zero( + EmbeddingOffloadScaleupProposer.get_expected_lookups( + sharding_option + ) + ) + for sharding_option in cache_tables + ] + ) + new_clfs = EmbeddingOffloadScaleupProposer.allocate_budget( + model=size_model, + clfs=clfs, + budget=budget, + allocation_priority=cooked_cacheability, + ) + # apply new_clfs, promoting tables that made it to 1.0 + for sharding_option, clf in zip(cache_tables, new_clfs): + clf = clf.item() # tensor scalar -> scalar + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + if clf > 0.9999: # tolerate float roundoff + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = None + sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value + # recalculate cost estimates of modified tables + enumerator.populate_estimates(cache_tables) + return proposal + + @staticmethod + def get_cacheability(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.cacheability + + @staticmethod + def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.stats.expected_lookups + + @staticmethod + def get_load_factor(sharding_option: ShardingOption) -> Optional[float]: + # helper to appease pyre type checker, as cache_params is Optional it maybe None + if ( + sharding_option.cache_params is None + or sharding_option.cache_params.stats is None + ): + return None + return sharding_option.cache_params.load_factor + + # The relationship between clf and shard memory usage is non-linear due to non-clf + # overheads like optimization stats and input/output storage. We model it as an + # affine relationship: bytes = clf * A + B where B is fixed overhead independent of + # CLF (e.g. input / output IO sizes and A is per cache-row overhead. + @staticmethod + def build_affine_storage_model( + uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator + ) -> torch.Tensor: + plan: List[ShardingOption] = copy.deepcopy(uvm_caching_sharding_options) + + def compute_hbm_sizes(clf: float) -> torch.Tensor: + for sharding_option in plan: + assert sharding_option.cache_params # appease pyre + sharding_option.cache_params.load_factor = clf + enumerator.populate_estimates(plan) + return torch.tensor( + [sharding_option.total_storage.hbm for sharding_option in plan] + ) + + low_clf, high_clf = 0.1, 0.9 + low_hbms = compute_hbm_sizes(low_clf) + high_hbms = compute_hbm_sizes(high_clf) + + A = (high_hbms - low_hbms) / (high_clf - low_clf) + B = low_hbms - A * low_clf + return torch.stack((A, B), dim=1) # Nx2 (a,b) + + @staticmethod + def clf_to_bytes( + model: torch.Tensor, clfs: Union[float, torch.Tensor] + ) -> torch.Tensor: + # evaluate affine model AX + B + return (model[:, 0] * clfs + model[:, 1]).to(torch.int64) + + # Given a model of an affine system, an existing configuration (clfs), available + # budget, and an allocation policy, return new configuration that best uses the + # available budget. We only add additional budget, we assume the existing + # configuration is specifying a floor or minimum size. + @staticmethod + def allocate_budget( + model: torch.Tensor, + clfs: torch.Tensor, + budget: int, + allocation_priority: torch.Tensor, + ) -> torch.Tensor: + # min size is size of table at 0 CLF + min_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 0) + max_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1) + table_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, clfs) + cache_size_bytes = table_size_bytes - min_size_bytes + max_cache_size_bytes = max_size_bytes - min_size_bytes + + # We have budget bytes to share across the tables in. We want to increase the + # cache_size_bytes of each table in proportion to their allocation priority + # fraction. If we raise the cache_size_bytes to beyond max_cache_size_bytes, + # this is equivalent to reaching CLF=1.0, so we clip the memory to 1.0, and + # reassign the released budget in a subsequent pass. + num_pass = 0 + while budget > 1 and num_pass < 128: + num_pass += 1 + # mask is False for tables at >= max_size, and True otherwise. This allows + # us to remove tables that have already reached full size in one round from + # being dealt more budget in subsequent rounds. + mask = (min_size_bytes + cache_size_bytes) < max_size_bytes + if mask.sum() == 0: + break + + logging.debug( + f"[allocate_budget] pass={num_pass}, budget={budget}, #cache_tables={mask.sum()}" + ) + + # switch to 64bit float to avoid rounding errors, as table cache sizes can + # easily be > 2^24. + masked_priority = (mask * allocation_priority).to(torch.float64) + increase_ratio = masked_priority / torch.sum(masked_priority) + proposed_increase_bytes = budget * increase_ratio + new_cache_size_bytes = torch.minimum( + cache_size_bytes + proposed_increase_bytes, max_cache_size_bytes + ) + actual_increase_bytes = new_cache_size_bytes - cache_size_bytes + + budget -= torch.sum(actual_increase_bytes) + cache_size_bytes = new_cache_size_bytes + # TODO: consider trade off of using remaining budget to push >0.95 tables + # to HBM vs spending that budget on improving hit rate on other tables in + # next pass. + + # cache_size_bytes are the new cache sizes we want to use. We convert them back + # to clfs by dividing by max_cache_size_bytes, which has isolated the clf + # portion of the table size from the fixed overheads. + # convert 64bit values back to original clf precision + return (cache_size_bytes / max_cache_size_bytes).to(clfs.dtype) + + def _sharding_option_score( sharding_option: ShardingOption, use_depth: bool = True ) -> float: diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index 6fd3c8998..89574032c 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -10,10 +10,12 @@ from unittest.mock import MagicMock import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.constants import BATCH_SIZE from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.proposers import ( + EmbeddingOffloadScaleupProposer, GreedyProposer, GridSearchProposer, proposers_to_proposals_list, @@ -21,12 +23,18 @@ ) from torchrec.distributed.planner.types import ( Enumerator, + ParameterConstraints, Proposer, ShardingOption, Topology, ) from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.distributed.types import ( + CacheParams, + CacheStatistics, + ModuleSharder, + ShardingType, +) from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -51,6 +59,23 @@ def propose(self) -> Optional[List[ShardingOption]]: pass +class MockCacheStatistics(CacheStatistics): + def __init__(self, expected_lookups: int, cacheability: float) -> None: + self._expected_lookups = expected_lookups + self._cacheability = cacheability + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + return clf + + @property + def cacheability(self) -> float: + return self._cacheability + + class TestProposers(unittest.TestCase): def setUp(self) -> None: topology = Topology(world_size=2, compute_device="cuda") @@ -316,6 +341,293 @@ def test_grid_search_three_table(self) -> None: self.assertEqual(num_pruned_options ** len(tables), num_proposals) + def test_allocate_budget(self) -> None: + model = torch.tensor([[1.0, 0.0], [2.0, 3.0], [4.0, 5.0]]) + got = EmbeddingOffloadScaleupProposer.clf_to_bytes( + model, torch.tensor([0, 0.5, 1]) + ) + torch.testing.assert_close(got, torch.tensor([0, 4, 9])) + + # Scenario 1, enough budget to scale everything to 1.0 + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 100_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + clfs=torch.tensor(mins), + budget=budget, + allocation_priority=torch.tensor([2, 2, 2]), + ) + torch.testing.assert_close(got, torch.tensor([1.0, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ).item() + self.assertLess(increase, budget) + + # Scenario 2, limited budget, uniform scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 2, 2]) + ) + torch.testing.assert_close(got, torch.tensor([0.26667, 0.26667, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 3, limited budget, skewed scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.1, 1]) + budget = 10_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, clfs=mins, budget=budget, allocation_priority=torch.tensor([2, 4, 2]) + ) + # increase is twice as much for table 2 (started at 0.1) + torch.testing.assert_close( + got, torch.tensor([0.1 + 0.11111, 0.1 + 2 * 0.11111, 1.0]) + ) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + # Scenario 4, multi-pass scale up + model = torch.tensor( + [[30_000_000, 2_000_000], [30_000_000, 2_000_000], [30_000_000, 2_000_000]] + ) + mins = torch.tensor([0.1, 0.3, 0.5]) + budget = 50_000_000 + got = EmbeddingOffloadScaleupProposer.allocate_budget( + model, + clfs=mins, + budget=budget, + allocation_priority=torch.tensor([1, 2, 100]), + ) + torch.testing.assert_close(got, torch.tensor([0.56667, 1.0, 1.0])) + increase = ( + EmbeddingOffloadScaleupProposer.clf_to_bytes(model, got).sum() + - EmbeddingOffloadScaleupProposer.clf_to_bytes(model, mins).sum() + ) + self.assertEqual(increase, budget) + + def test_scaleup(self) -> None: + + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + + # Place first two tables into cache, 3rd table leave on hbm. table_1 has a + # larger cacheability score so budget should be skewed to scaling table_1 more + # than table_0. + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5), + ), + ), + } + + MB = 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB + ) + + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + output = [] + proposal = proposer.propose() + while proposal is not None: + output.append( + [ + ( + candidate.name, + candidate.compute_kernel, + candidate.cache_params.load_factor + if candidate.cache_params + else None, + ) + for candidate in proposal + ] + ) + proposer.feedback( + partitionable=True, + plan=proposal, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + # Expected output (name, kernel clf). + # First attempt uses the mins supplied, then as we apply increasing budget + # clfs increase, with the later attempts enough to promote table_3 into hbm. + expected_output = [ + [ + ("table_0", "fused_uvm_caching", 0.1), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.3025801181793213), + ("table_1", "fused_uvm_caching", 0.6064502596855164), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.403870165348053), + ("table_1", "fused_uvm_caching", 0.859675407409668), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.4545151889324188), + ("table_1", "fused_uvm_caching", 0.9862880110740662), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.5294319987297058), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.573746383190155), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.5959035754203796), + ("table_1", "fused", None), + ("table_2", "fused", None), + ], + ] + + self.assertEqual(output, expected_output) + + def test_scaleup_ample_budget_and_deprecated_feature(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=2_000_000, + embedding_dim=10, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + + # Place first two tables into cache, 3rd table leave on hbm. table_1 has an + # expected lookup of 0 (deprecated feature). + constraints = { + "table_0": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2), + ), + ), + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], + cache_params=CacheParams( + load_factor=0.1, + stats=MockCacheStatistics(expected_lookups=0, cacheability=0), + ), + ), + } + + MB = 1024 * 1024 + storage_constraint = Topology( + world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB + ) + + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + enumerator = EmbeddingEnumerator( + topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints + ) + search_space = enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + proposer = EmbeddingOffloadScaleupProposer() + proposer.load(search_space, enumerator=enumerator) + + output = [] + proposal = proposer.propose() + while proposal is not None: + output.append( + [ + ( + candidate.name, + candidate.compute_kernel, + candidate.cache_params.load_factor + if candidate.cache_params + else None, + ) + for candidate in proposal + ] + ) + proposer.feedback( + partitionable=True, + plan=proposal, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + + # Expected output (name, kernel clf). + # First attempt uses the mins supplied, then as we apply increasing budget + # clfs increase, table 0 gets promoted, table 1 left as original minimum. + expected_output = [ + [ + ("table_0", "fused_uvm_caching", 0.1), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused_uvm_caching", 0.8090304136276245), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + [ + ("table_0", "fused", None), + ("table_1", "fused_uvm_caching", 0.1), + ("table_2", "fused", None), + ], + ] + self.assertEqual(output[0:3], expected_output) + def test_proposers_to_proposals_list(self) -> None: def make_mock_proposal(name: str) -> List[ShardingOption]: return [ diff --git a/torchrec/distributed/planner/tests/test_utils.py b/torchrec/distributed/planner/tests/test_utils.py index c28b5e4c3..c1b9e8583 100644 --- a/torchrec/distributed/planner/tests/test_utils.py +++ b/torchrec/distributed/planner/tests/test_utils.py @@ -6,11 +6,15 @@ # LICENSE file in the root directory of this source tree. import unittest -from typing import List +from typing import Callable, List from unittest.mock import MagicMock from torchrec.distributed.planner.types import Perf, Shard, ShardingOption, Storage -from torchrec.distributed.planner.utils import _find_imbalance_tables, reset_shard_rank +from torchrec.distributed.planner.utils import ( + _find_imbalance_tables, + BinarySearchPredicate, + reset_shard_rank, +) from torchrec.distributed.types import ShardingType @@ -74,3 +78,36 @@ def test_find_hbm_imbalance_tables(self) -> None: ) ] self.assertTrue(expected_max_hbm_table_names, max_hbm_table_names) + + +class TestBinarySearchPredicate(unittest.TestCase): + def test_binary_search_predicate(self) -> None: + def F(x: int) -> bool: + return x < 90 + + def probes( + search: BinarySearchPredicate, f: Callable[[int], bool] + ) -> List[int]: + r = [] + probe = search.next(True) + while probe is not None: + r.append(probe) + probe = search.next(f(probe)) + return r + + got = probes(BinarySearchPredicate(0, 100, 0), F) + self.assertEqual(got, [50, 75, 88, 94, 91, 89, 90]) + got = probes(BinarySearchPredicate(0, 100, 3), F) + self.assertEqual(got, [50, 75, 88, 94, 91]) + got = probes(BinarySearchPredicate(0, 100, 20), F) + self.assertEqual(got, [50, 75, 88]) + + got = probes(BinarySearchPredicate(91, 100, 0), F) + self.assertEqual(got, [95, 92, 91]) + got = probes(BinarySearchPredicate(1, 10, 0), F) + self.assertEqual(got, [5, 8, 9, 10]) + + got = probes(BinarySearchPredicate(1, 1, 0), F) + self.assertEqual(got, [1]) + got = probes(BinarySearchPredicate(1, 0, 0), F) + self.assertEqual(got, []) diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py index 715869311..4594396f8 100644 --- a/torchrec/distributed/planner/utils.py +++ b/torchrec/distributed/planner/utils.py @@ -123,3 +123,43 @@ def _find_imbalance_tables( raise ValueError(f"Unknown target imbalance {target_imbalance}") return tables_in_max_value_ranks + + +class BinarySearchPredicate: + """Generates values of X between A & B to invoke on an external predicate F(X) to + discover the largest X for which F(X) is true. Uses binary search to minimize the + number of invocations of F. Assumes F is a step function, i.e. if F(X) is false, + there is no point trying F(X+1).""" + + def __init__(self, A: int, B: int, tolerance: int) -> None: + """A = lower boundary (inclusive) + B = upper boundary (inclusive) + tolerance = stop search early if remaining search range is less than tolerance""" + self.left = A + self.right = B + self.tolerance = tolerance + self.first = True + + def next(self, prior_result: bool) -> Optional[int]: + """next() returns the next value to probe, given the result of the prior probe. + The first time next() is invoked the prior_result is ignored. Returns None if + entire range explored or threshold reached.""" + if self.right - self.left < self.tolerance: + return None + + mid = self._mid() + if self.first: + self.first = False + return mid + + if prior_result: + self.left = mid + 1 + else: + self.right = mid - 1 + if self.right - self.left < self.tolerance: + return None + + return self._mid() + + def _mid(self) -> int: + return self.left + ((self.right - self.left) // 2)