Skip to content

Commit

Permalink
Extend torchrec CacheParams to include CacheStatistics
Browse files Browse the repository at this point in the history
Summary:
Extend torchrec CacheParams with cacheability data (miss ratio curves
for embedding tables), so we can make smarter decisions about
uvm_cache sizes and shard placement in the planner.

Reviewed By: henrylhtsang

Differential Revision: D51129138
  • Loading branch information
damianr99 authored and facebook-github-bot committed Nov 20, 2023
1 parent 873faf3 commit 6ed5f92
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 1 deletion.
86 changes: 85 additions & 1 deletion torchrec/distributed/planner/shard_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
Topology,
)
from torchrec.distributed.planner.utils import prod, sharder_name
from torchrec.distributed.types import CommOp, ModuleSharder, ShardingType
from torchrec.distributed.types import (
CacheStatistics,
CommOp,
ModuleSharder,
ShardingType,
)

from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface

Expand Down Expand Up @@ -1191,3 +1196,82 @@ def _get_optimizer_multipler(
return 1 / shape[-1]
else:
return 1


class EmbeddingOffloadStats(CacheStatistics):
"""Computes cache statistics for uvm_fused_cache tables.
Args:
cachebility (float):
The area-under-the-curve of miss-ratio curve.
expected_lookups (float):
The expected number of unique embedding ids per global batch.
mrc_hist_counts (torch.Tensor):
A 1d tensor (size n) holding a histogram of LRU miss ratio curve. Each bin
represents 1/nth of possible LRU cache sizes (from load_factor 0 to load_factor
1.0). The bin contains the number of expected LRU operations that could be
handled without a cache miss if the LRU load_factor was at least that size.
height (int):
The height (num_embeddings) of the embedding table.
"""

def __init__(
self,
cacheability: float,
expected_lookups: int,
mrc_hist_counts: torch.Tensor,
height: int,
) -> None:
self._cacheability = cacheability
self._expected_lookups = expected_lookups
self.height = height

if mrc_hist_counts.dim() != 1:
raise ValueError(f"expected 1d tensor, got {mrc_hist_counts.dim()}d")
if mrc_hist_counts.size()[0] == 0:
raise ValueError("expected non-empty tensor")

self.hist: torch.Tensor = mrc_hist_counts
self.bins: torch.Tensor = torch.linspace(0, height, len(mrc_hist_counts) + 1)

@property
def expected_lookups(self) -> int:
return self._expected_lookups

def expected_miss_rate(self, clf: float) -> float:
cache_size = torch.tensor(clf * self.height)
miss_rate = EmbeddingOffloadStats.estimate_cache_miss_rate(
cache_sizes=cache_size, hist=self.hist, bins=self.bins
)
return miss_rate.item()

@property
def cacheability(self) -> float:
return self._cacheability

@staticmethod
def estimate_cache_miss_rate(
cache_sizes: torch.Tensor, hist: torch.Tensor, bins: torch.Tensor
) -> torch.Tensor:
"""Calculate estimated cache miss ratio for the proposed cache_sizes, given the MRC
histogram.
"""
ys = hist.cumsum(dim=0)
ys = ys / ys[-1] # rescale [0,1]
ys = 1 - ys # make miss-ratio, not hit-ratio

# torch.bucketize has slightly different semantics to np.digitize,
# and np.digitize has a complex interface, read the docs carefully!
# we're trying to reverse the ops of np.histogram, indices are one larger than
# the insert positions, since with right=True, index returned such that x <
# bins[index], so x 'lives' in hist[index-1]
# A cache size of k will get hits for all stack distances of upto k-1 inclusive.
larger_bin_indices = torch.bucketize(cache_sizes - 1, bins, right=True)
# Augment ys to deal with torch.bucketize boundary conditions:
# values outside of bins range map to 0, or len(bins).
# So we extend ys to populate sentinel values for these cases. With the twist that
# the left-hand sentinel we put on the right side of the array, as larger_bin_indices - 1
# maps 0 -> -1, which pytorch maps to most right hand value.
ys = torch.cat((ys, torch.tensor([0.0, 1.0])))
return ys[larger_bin_indices - 1]
54 changes: 54 additions & 0 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.shard_estimators import (
_calculate_storage_specific_sizes,
EmbeddingOffloadStats,
EmbeddingPerfEstimator,
)
from torchrec.distributed.planner.types import Perf, Topology
Expand Down Expand Up @@ -447,3 +448,56 @@ def test_calculate_storage_specific_sizes(self) -> None:
)

self.assertEqual(estimates, expected_storage)


class TestEmbeddingOffloadStats(unittest.TestCase):
def test_basic(self) -> None:
stats = EmbeddingOffloadStats(
cacheability=0.42,
expected_lookups=31,
mrc_hist_counts=torch.tensor([99, 98, 97]),
height=92,
)
self.assertEqual(stats.cacheability, 0.42)
self.assertEqual(stats.expected_lookups, 31)
self.assertEqual(stats.expected_miss_rate(0), 1.0)
self.assertEqual(stats.expected_miss_rate(1), 0.0)
self.assertAlmostEqual(
stats.expected_miss_rate(0.5), 1 - (99 + 98) / (99 + 98 + 97)
)

def test_estimate_cache_miss_rate(self) -> None:
hist = torch.tensor([0, 6, 0, 8])
bins = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate(
torch.tensor([0, 1, 2, 3, 4]), hist, bins
)
m = 1 - (6 / (6 + 8)) # from hist counts above
want = [
1, # size 0 - 100% miss
1, # size 1 - 100%, no immediate repetitions
m, # size 2 - m (~57%) miss, 6 occurrences
m, # size 3 - same as size 2, no 3 stack distances,
# so increasing cache by 1 doesn't help
0, # size 4 - 0% miss rate, everything fits
]
torch.testing.assert_close(miss_rates, torch.tensor(want))

# test with bigger bins to better validate boundary conditions
# create simple linear miss rate curve
trace = torch.arange(100.0)
hist = torch.histc(trace, bins=10, min=0, max=100)
bins = torch.linspace(0, 100, len(hist) + 1)
miss_rates = EmbeddingOffloadStats.estimate_cache_miss_rate(
torch.tensor([0, 9, 10, 11, 89, 99, 100]), hist, bins
)
want = [
1, # 0 -> no cache, 100% miss
0.9, # 9 -> bin 0, which is all cache sizes <= 10, has 90 misses of 100, so 90% miss
0.9, # 10 -> bin 0, same as above
0.8, # 11 -> bin 1, cache sizes (10, 20], 80 misses out of 100, so 80% miss
0.1, # 89 -> bin 8, cache sizes (80, 90], 10 misses out of 100, so 10% miss
0, # 99 -> bin 9, cache sizes (90, 100], final last bin gets scaled to 1, so 0% misses
0, # 100 -> off the end of the histogram, 0% misses
]
torch.testing.assert_close(miss_rates, torch.tensor(want))
26 changes: 26 additions & 0 deletions torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,39 @@ class ModuleShardingPlan:
pass


class CacheStatistics(abc.ABC):
@property
@abc.abstractmethod
def expected_lookups(self) -> float:
"""Number of expected cache lookups per training step.
This is the expected number of distinct values in a global training batch."""

@abc.abstractmethod
def expected_miss_rate(self, clf: float) -> float:
"""Expected cache lookup miss rate for a given cache size.
When clf (cache load factor) is 0, returns 1.0 (100% miss). When clf is 1.0,
returns 0 (100% hit). For values of clf between these extremes, returns the
estimated miss rate of the cache, e.g. based on knowledge of the statistical
properties of the training data set."""

@property
@abc.abstractmethod
def cacheability(self) -> float:
"""Summarized measure of the difficulty to cache a dataset that is independent of
cache size. A score of 0 means the dataset is very cacheable (e.g. high locality
between accesses), a score of 1 is very difficult to cache."""


@dataclass
class CacheParams:
algorithm: Optional[CacheAlgorithm] = None
load_factor: Optional[float] = None
reserved_memory: Optional[float] = None
precision: Optional[DataType] = None
prefetch_pipeline: Optional[bool] = None
stats: Optional[CacheStatistics] = None

def __hash__(self) -> int:
return hash(
Expand Down

0 comments on commit 6ed5f92

Please sign in to comment.