Skip to content

Commit

Permalink
Implement EmbeddingOffloadScaleupProposer
Browse files Browse the repository at this point in the history
Summary:
Implements a new type of Proposer that attempts to scale up
fused_uvm_caches individually according to an allocation policy based
on the expected statistical distribution of the cache workload and a
budget of HBM memory that is available for caching. Scaling
fused_uvm_caches identically (e.g. using a global default load factor)
is suboptimal as we find significant differences between the cache
load factors needed for different embedding tables to achieve a
reasonable miss rate.

This diff just implements the Proposer, but does not (yet) use it by
default. To enable the new Proposer, the trainer should explicitly
specify this Proposer when initializing the EmbeddingShardingPlanner.

The cost model for fused_uvm_caching does not yet fully account for
storage and perf overheads of the cache. So this proposer should not
be used in conjunction with other proposers in the planner. In a later
diff we will improve the cost model so proposals generated by
caching-aware proposers and existing proposers are comparable,
removing this restriction.

Differential Revision: D51451167
  • Loading branch information
damianr99 authored and facebook-github-bot committed Dec 6, 2023
1 parent c93eef0 commit 2179eda
Show file tree
Hide file tree
Showing 4 changed files with 675 additions and 5 deletions.
285 changes: 283 additions & 2 deletions torchrec/distributed/planner/proposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import logging
from decimal import Decimal
from typing import cast, Dict, List, Optional, Set, Tuple
from typing import cast, Dict, List, Optional, Set, Tuple, Union

import torch

from torchrec.distributed.embedding_types import EmbeddingComputeKernel

from torchrec.distributed.planner.types import (
Enumerator,
Expand All @@ -17,7 +22,7 @@
ShardingOption,
Topology,
)
from torchrec.distributed.planner.utils import prod
from torchrec.distributed.planner.utils import BinarySearchPredicate, prod

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -272,6 +277,282 @@ def feedback(
self._proposal_index += 1


class EmbeddingOffloadScaleupProposer(Proposer):
def __init__(self, use_depth: bool = True) -> None:
self.use_depth: bool = use_depth
self.enumerator: Optional[Enumerator] = None
self.starting_proposal: List[ShardingOption] = []
self.proposal: Optional[List[ShardingOption]] = None
self.search: Optional[BinarySearchPredicate] = None

def load(
self,
search_space: List[ShardingOption],
enumerator: Optional[Enumerator] = None,
) -> None:
self.enumerator = enumerator
sharding_options_by_fqn: Dict[str, List[ShardingOption]] = {}
for sharding_option in search_space:
sharding_options_by_fqn.setdefault(sharding_option.fqn, []).append(
sharding_option
)
for sharding_options in sharding_options_by_fqn.values():
sharding_options.sort(
key=lambda x: _sharding_option_score(x, self.use_depth)
)
# currently only use 1st sharding option for proposal only.
# TODO: could traverse through multiple options like GreedyProposer
proposal = [
sharding_options[0] for sharding_options in sharding_options_by_fqn.values()
]
# deepcopy so it won't affect other proposers
self.starting_proposal = copy.deepcopy(proposal)
self.proposal = copy.deepcopy(self.starting_proposal)

def propose(self) -> Optional[List[ShardingOption]]:
return self.proposal

def feedback(
self,
partitionable: bool,
plan: Optional[List[ShardingOption]] = None,
perf_rating: Optional[float] = None,
storage_constraint: Optional[Topology] = None,
) -> None:
if not self.enumerator or not plan or not storage_constraint:
self.proposal = None
return

if self.search is None:
# Determine how much extra HBM memory is available for scaling our caches
# beyond the baseline-min-working-set plan. We may not be able to find a
# partitionable plan that uses all this budget, or we may find a plan that
# uses only a portion of this budget enables a layout that reduces overall
# cost at the expense of larger prefetch delay. So we perform a binary
# search to sample plans with different budgets to discover a good
# configuration.
hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
plan, storage_constraint
)
# Partitioning proposals is quite expensive when there are a lot of tables,
# so we reduce the number probes the binary search uses to find the max
# cache sizes that fit inside the budget by specifying a tolerance. Once
# we've less than tolerance bytes left of unused budget we stop searching.
# We set tolerance to try to waste less than 3% of budget. For 100TB budget,
# this reduces number of proposals from 47 to 6.
tolerance = round(hbm_available * 0.03)
self.search = BinarySearchPredicate(0, hbm_available, tolerance)

budget = self.search.next(partitionable)
self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
self.starting_proposal, budget, self.enumerator
)

@staticmethod
def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int:
"""returns additional HBM budget available for GPU caches."""
available_hbm = sum(device.storage.hbm for device in storage_constraint.devices)
used_hbm = sum(
sharding_option.total_storage.hbm for sharding_option in proposal
)
return available_hbm - used_hbm

# Given an available budget of additional memory, and a provisional sharding plan,
# attempt to use the budget wisely to scale up caches that would most benefit from it.
@staticmethod
def next_plan(
starting_proposal: List[ShardingOption],
budget: Optional[int],
enumerator: Optional[Enumerator],
) -> Optional[List[ShardingOption]]:
if budget is None or enumerator is None:
return None

def n2z(x: Optional[float]) -> float:
return x if x is not None else 0.0

proposal = copy.deepcopy(starting_proposal)
# This is the subset of tables that we can scale
cache_tables = [
sharding_option
for sharding_option in proposal
if sharding_option.compute_kernel
== EmbeddingComputeKernel.FUSED_UVM_CACHING.value
and n2z(EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option))
* n2z(EmbeddingOffloadScaleupProposer.get_expected_lookups(sharding_option))
* n2z(EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option))
> 0
]
# Nothing to scale
if len(cache_tables) == 0:
return None

size_model = EmbeddingOffloadScaleupProposer.build_affine_storage_model(
cache_tables, enumerator
)
clfs = torch.tensor(
[
EmbeddingOffloadScaleupProposer.get_load_factor(sharding_option)
for sharding_option in cache_tables
]
)
# cooked_cacheability is cacheability scaled by the expected number of cache
# lookups.

cooked_cacheability = torch.tensor(
[
n2z(EmbeddingOffloadScaleupProposer.get_cacheability(sharding_option))
* n2z(
EmbeddingOffloadScaleupProposer.get_expected_lookups(
sharding_option
)
)
for sharding_option in cache_tables
]
)
new_clfs = EmbeddingOffloadScaleupProposer.allocate_budget(
model=size_model,
clfs=clfs,
budget=budget,
allocation_priority=cooked_cacheability,
)
# apply new_clfs, promoting tables that made it to 1.0
for sharding_option, clf in zip(cache_tables, new_clfs):
clf = clf.item() # tensor scalar -> scalar
assert sharding_option.cache_params # appease pyre
sharding_option.cache_params.load_factor = clf
if clf > 0.9999: # tolerate float roundoff
assert sharding_option.cache_params # appease pyre
sharding_option.cache_params.load_factor = None
sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value
# recalculate cost estimates of modified tables
enumerator.populate_estimates(cache_tables)
return proposal

@staticmethod
def get_cacheability(sharding_option: ShardingOption) -> Optional[float]:
# helper to appease pyre type checker, as cache_params is Optional it maybe None
if (
sharding_option.cache_params is None
or sharding_option.cache_params.stats is None
):
return None
return sharding_option.cache_params.stats.cacheability

@staticmethod
def get_expected_lookups(sharding_option: ShardingOption) -> Optional[float]:
# helper to appease pyre type checker, as cache_params is Optional it maybe None
if (
sharding_option.cache_params is None
or sharding_option.cache_params.stats is None
):
return None
return sharding_option.cache_params.stats.expected_lookups

@staticmethod
def get_load_factor(sharding_option: ShardingOption) -> Optional[float]:
# helper to appease pyre type checker, as cache_params is Optional it maybe None
if (
sharding_option.cache_params is None
or sharding_option.cache_params.stats is None
):
return None
return sharding_option.cache_params.load_factor

# The relationship between clf and shard memory usage is non-linear due to non-clf
# overheads like optimization stats and input/output storage. We model it as an
# affine relationship: bytes = clf * A + B where B is fixed overhead independent of
# CLF (e.g. input / output IO sizes and A is per cache-row overhead.
@staticmethod
def build_affine_storage_model(
uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator
) -> torch.Tensor:
plan: List[ShardingOption] = copy.deepcopy(uvm_caching_sharding_options)

def compute_hbm_sizes(clf: float) -> torch.Tensor:
for sharding_option in plan:
assert sharding_option.cache_params # appease pyre
sharding_option.cache_params.load_factor = clf
enumerator.populate_estimates(plan)
return torch.tensor(
[sharding_option.total_storage.hbm for sharding_option in plan]
)

low_clf, high_clf = 0.1, 0.9
low_hbms = compute_hbm_sizes(low_clf)
high_hbms = compute_hbm_sizes(high_clf)

A = (high_hbms - low_hbms) / (high_clf - low_clf)
B = low_hbms - A * low_clf
return torch.stack((A, B), dim=1) # Nx2 (a,b)

@staticmethod
def clf_to_bytes(
model: torch.Tensor, clfs: Union[float, torch.Tensor]
) -> torch.Tensor:
# evaluate affine model AX + B
return (model[:, 0] * clfs + model[:, 1]).to(torch.int64)

# Given a model of an affine system, an existing configuration (clfs), available
# budget, and an allocation policy, return new configuration that best uses the
# available budget. We only add additional budget, we assume the existing
# configuration is specifying a floor or minimum size.
@staticmethod
def allocate_budget(
model: torch.Tensor,
clfs: torch.Tensor,
budget: int,
allocation_priority: torch.Tensor,
) -> torch.Tensor:
# min size is size of table at 0 CLF
min_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 0)
max_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, 1)
table_size_bytes = EmbeddingOffloadScaleupProposer.clf_to_bytes(model, clfs)
cache_size_bytes = table_size_bytes - min_size_bytes
max_cache_size_bytes = max_size_bytes - min_size_bytes

# We have budget bytes to share across the tables in. We want to increase the
# cache_size_bytes of each table in proportion to their allocation priority
# fraction. If we raise the cache_size_bytes to beyond max_cache_size_bytes,
# this is equivalent to reaching CLF=1.0, so we clip the memory to 1.0, and
# reassign the released budget in a subsequent pass.
num_pass = 0
while budget > 1 and num_pass < 128:
num_pass += 1
# mask is False for tables at >= max_size, and True otherwise. This allows
# us to remove tables that have already reached full size in one round from
# being dealt more budget in subsequent rounds.
mask = (min_size_bytes + cache_size_bytes) < max_size_bytes
if mask.sum() == 0:
break

logging.debug(
f"[allocate_budget] pass={num_pass}, budget={budget}, #cache_tables={mask.sum()}"
)

# switch to 64bit float to avoid rounding errors, as table cache sizes can
# easily be > 2^24.
masked_priority = (mask * allocation_priority).to(torch.float64)
increase_ratio = masked_priority / torch.sum(masked_priority)
proposed_increase_bytes = budget * increase_ratio
new_cache_size_bytes = torch.minimum(
cache_size_bytes + proposed_increase_bytes, max_cache_size_bytes
)
actual_increase_bytes = new_cache_size_bytes - cache_size_bytes

budget -= torch.sum(actual_increase_bytes)
cache_size_bytes = new_cache_size_bytes
# TODO: consider trade off of using remaining budget to push >0.95 tables
# to HBM vs spending that budget on improving hit rate on other tables in
# next pass.

# cache_size_bytes are the new cache sizes we want to use. We convert them back
# to clfs by dividing by max_cache_size_bytes, which has isolated the clf
# portion of the table size from the fixed overheads.
# convert 64bit values back to original clf precision
return (cache_size_bytes / max_cache_size_bytes).to(clfs.dtype)


def _sharding_option_score(
sharding_option: ShardingOption, use_depth: bool = True
) -> float:
Expand Down
Loading

0 comments on commit 2179eda

Please sign in to comment.