From 6ed5f9251c87c23398f89d618ef0ff9bf1990b8c Mon Sep 17 00:00:00 2001 From: Damian Reeves Date: Mon, 20 Nov 2023 13:28:53 -0800 Subject: [PATCH] Extend torchrec CacheParams to include CacheStatistics Summary: Extend torchrec CacheParams with cacheability data (miss ratio curves for embedding tables), so we can make smarter decisions about uvm_cache sizes and shard placement in the planner. Reviewed By: henrylhtsang Differential Revision: D51129138 --- .../distributed/planner/shard_estimators.py | 86 ++++++++++++++++++- .../planner/tests/test_shard_estimators.py | 54 ++++++++++++ torchrec/distributed/types.py | 26 ++++++ 3 files changed, 165 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index d5f1f5fc5..0f262afab 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -34,7 +34,12 @@ Topology, ) from torchrec.distributed.planner.utils import prod, sharder_name -from torchrec.distributed.types import CommOp, ModuleSharder, ShardingType +from torchrec.distributed.types import ( + CacheStatistics, + CommOp, + ModuleSharder, + ShardingType, +) from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface @@ -1191,3 +1196,82 @@ def _get_optimizer_multipler( return 1 / shape[-1] else: return 1 + + +class EmbeddingOffloadStats(CacheStatistics): + """Computes cache statistics for uvm_fused_cache tables. + + Args: + + cachebility (float): + The area-under-the-curve of miss-ratio curve. + expected_lookups (float): + The expected number of unique embedding ids per global batch. + mrc_hist_counts (torch.Tensor): + A 1d tensor (size n) holding a histogram of LRU miss ratio curve. Each bin + represents 1/nth of possible LRU cache sizes (from load_factor 0 to load_factor + 1.0). The bin contains the number of expected LRU operations that could be + handled without a cache miss if the LRU load_factor was at least that size. + height (int): + The height (num_embeddings) of the embedding table. + """ + + def __init__( + self, + cacheability: float, + expected_lookups: int, + mrc_hist_counts: torch.Tensor, + height: int, + ) -> None: + self._cacheability = cacheability + self._expected_lookups = expected_lookups + self.height = height + + if mrc_hist_counts.dim() != 1: + raise ValueError(f"expected 1d tensor, got {mrc_hist_counts.dim()}d") + if mrc_hist_counts.size()[0] == 0: + raise ValueError("expected non-empty tensor") + + self.hist: torch.Tensor = mrc_hist_counts + self.bins: torch.Tensor = torch.linspace(0, height, len(mrc_hist_counts) + 1) + + @property + def expected_lookups(self) -> int: + return self._expected_lookups + + def expected_miss_rate(self, clf: float) -> float: + cache_size = torch.tensor(clf * self.height) + miss_rate = EmbeddingOffloadStats.estimate_cache_miss_rate( + cache_sizes=cache_size, hist=self.hist, bins=self.bins + ) + return miss_rate.item() + + @property + def cacheability(self) -> float: + return self._cacheability + + @staticmethod + def estimate_cache_miss_rate( + cache_sizes: torch.Tensor, hist: torch.Tensor, bins: torch.Tensor + ) -> torch.Tensor: + """Calculate estimated cache miss ratio for the proposed cache_sizes, given the MRC + histogram. + """ + ys = hist.cumsum(dim=0) + ys = ys / ys[-1] # rescale [0,1] + ys = 1 - ys # make miss-ratio, not hit-ratio + + # torch.bucketize has slightly different semantics to np.digitize, + # and np.digitize has a complex interface, read the docs carefully! + # we're trying to reverse the ops of np.histogram, indices are one larger than + # the insert positions, since with right=True, index returned such that x < + # bins[index], so x 'lives' in hist[index-1] + # A cache size of k will get hits for all stack distances of upto k-1 inclusive. + larger_bin_indices = torch.bucketize(cache_sizes - 1, bins, right=True) + # Augment ys to deal with torch.bucketize boundary conditions: + # values outside of bins range map to 0, or len(bins). + # So we extend ys to populate sentinel values for these cases. With the twist that + # the left-hand sentinel we put on the right side of the array, as larger_bin_indices - 1 + # maps 0 -> -1, which pytorch maps to most right hand value. + ys = torch.cat((ys, torch.tensor([0.0, 1.0]))) + return ys[larger_bin_indices - 1] diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index a172d93f7..9f099707d 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -22,6 +22,7 @@ from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.shard_estimators import ( _calculate_storage_specific_sizes, + EmbeddingOffloadStats, EmbeddingPerfEstimator, ) from torchrec.distributed.planner.types import Perf, Topology @@ -447,3 +448,56 @@ def test_calculate_storage_specific_sizes(self) -> None: ) self.assertEqual(estimates, expected_storage) + + +class TestEmbeddingOffloadStats(unittest.TestCase): + def test_basic(self) -> None: + stats = EmbeddingOffloadStats( + cacheability=0.42, + expected_lookups=31, + mrc_hist_counts=torch.tensor([99, 98, 97]), + height=92, + ) + self.assertEqual(stats.cacheability, 0.42) + self.assertEqual(stats.expected_lookups, 31) + self.assertEqual(stats.expected_miss_rate(0), 1.0) + self.assertEqual(stats.expected_miss_rate(1), 0.0) + self.assertAlmostEqual( + stats.expected_miss_rate(0.5), 1 - (99 + 98) / (99 + 98 + 97) + ) + + def test_estimate_cache_miss_rate(self) -> None: + hist = torch.tensor([0, 6, 0, 8]) + bins = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor([0, 1, 2, 3, 4]), hist, bins + ) + m = 1 - (6 / (6 + 8)) # from hist counts above + want = [ + 1, # size 0 - 100% miss + 1, # size 1 - 100%, no immediate repetitions + m, # size 2 - m (~57%) miss, 6 occurrences + m, # size 3 - same as size 2, no 3 stack distances, + # so increasing cache by 1 doesn't help + 0, # size 4 - 0% miss rate, everything fits + ] + torch.testing.assert_close(miss_rates, torch.tensor(want)) + + # test with bigger bins to better validate boundary conditions + # create simple linear miss rate curve + trace = torch.arange(100.0) + hist = torch.histc(trace, bins=10, min=0, max=100) + bins = torch.linspace(0, 100, len(hist) + 1) + miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate( + torch.tensor([0, 9, 10, 11, 89, 99, 100]), hist, bins + ) + want = [ + 1, # 0 -> no cache, 100% miss + 0.9, # 9 -> bin 0, which is all cache sizes <= 10, has 90 misses of 100, so 90% miss + 0.9, # 10 -> bin 0, same as above + 0.8, # 11 -> bin 1, cache sizes (10, 20], 80 misses out of 100, so 80% miss + 0.1, # 89 -> bin 8, cache sizes (80, 90], 10 misses out of 100, so 10% miss + 0, # 99 -> bin 9, cache sizes (90, 100], final last bin gets scaled to 1, so 0% misses + 0, # 100 -> off the end of the histogram, 0% misses + ] + torch.testing.assert_close(miss_rates, torch.tensor(want)) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 058b07ef9..e504791f9 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -451,6 +451,31 @@ class ModuleShardingPlan: pass +class CacheStatistics(abc.ABC): + @property + @abc.abstractmethod + def expected_lookups(self) -> float: + """Number of expected cache lookups per training step. + + This is the expected number of distinct values in a global training batch.""" + + @abc.abstractmethod + def expected_miss_rate(self, clf: float) -> float: + """Expected cache lookup miss rate for a given cache size. + + When clf (cache load factor) is 0, returns 1.0 (100% miss). When clf is 1.0, + returns 0 (100% hit). For values of clf between these extremes, returns the + estimated miss rate of the cache, e.g. based on knowledge of the statistical + properties of the training data set.""" + + @property + @abc.abstractmethod + def cacheability(self) -> float: + """Summarized measure of the difficulty to cache a dataset that is independent of + cache size. A score of 0 means the dataset is very cacheable (e.g. high locality + between accesses), a score of 1 is very difficult to cache.""" + + @dataclass class CacheParams: algorithm: Optional[CacheAlgorithm] = None @@ -458,6 +483,7 @@ class CacheParams: reserved_memory: Optional[float] = None precision: Optional[DataType] = None prefetch_pipeline: Optional[bool] = None + stats: Optional[CacheStatistics] = None def __hash__(self) -> int: return hash(