From 9434f25674310b49b1d6749e90836eb2ba272ac7 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 30 Oct 2023 17:16:51 -0700 Subject: [PATCH] Make changes to proposer design Summary: Allowing proposer to call the enumerator. Together with enumerate.populate_estimates, this allows the proposer to change the sharding options and re-estimate their perfs and storages. Differential Revision: D50514266 --- torchrec/distributed/planner/planners.py | 12 ++++++-- torchrec/distributed/planner/proposers.py | 29 ++++++++++++++++--- .../planner/tests/test_proposers.py | 9 +++++- torchrec/distributed/planner/types.py | 9 ++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index b054b39df..7a75210b2 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -229,7 +229,7 @@ def plan( ] = {} for proposer in self._proposers: - proposer.load(search_space=search_space) + proposer.load(search_space=search_space, enumerator=self._enumerator) for proposer in self._proposers: proposal = proposer.propose() @@ -242,6 +242,7 @@ def plan( partitionable=partitionable, plan=plan, perf_rating=perf_rating, + storage_constraint=storage_constraint, ) proposal = proposer.propose() continue @@ -260,7 +261,10 @@ def plan( best_plan = copy.deepcopy(plan) proposal_cache[proposal_key] = (True, plan, perf_rating) proposer.feedback( - partitionable=True, plan=plan, perf_rating=perf_rating + partitionable=True, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, ) except PlannerError as planner_error: last_planner_error = planner_error @@ -280,7 +284,9 @@ def plan( if current_storage < lowest_storage: lowest_storage = current_storage proposal_cache[proposal_key] = (False, None, None) - proposer.feedback(partitionable=False) + proposer.feedback( + partitionable=False, storage_constraint=storage_constraint + ) # clear shard.rank for each sharding_option reset_shard_rank(proposal) diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index f67cbf752..ef98fac44 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -10,7 +10,13 @@ from decimal import Decimal from typing import cast, Dict, List, Optional, Set, Tuple -from torchrec.distributed.planner.types import Perf, Proposer, ShardingOption +from torchrec.distributed.planner.types import ( + Enumerator, + Perf, + Proposer, + ShardingOption, + Topology, +) from torchrec.distributed.planner.utils import prod logger: logging.Logger = logging.getLogger(__name__) @@ -43,7 +49,11 @@ def __init__(self, use_depth: bool = True, threshold: Optional[int] = None) -> N self._best_perf_rating: float = float("inf") self._num_inferior_perf: int = 0 - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() for sharding_option in search_space: fqn = sharding_option.fqn @@ -78,6 +88,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # When threshold is passed, observe the perf_rating trend. If the perf_rating # of the newly proposed plans have worse perf_rating, stop proposing. @@ -126,7 +137,11 @@ def __init__(self, use_depth: bool = True) -> None: self._grouped_sharding_options: List[List[ShardingOption]] = [] self._proposal_index: int = 0 - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() all_fqns = set() sharding_options_by_type_and_fqn: Dict[ @@ -175,6 +190,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # static strategy, ignore feedback and just provide next proposal self._proposal_index += 1 @@ -187,7 +203,11 @@ def __init__(self, max_proposals: int = MAX_PROPOSALS) -> None: self._proposal_index: int = 0 self._proposals: List[List[int]] = [] - def load(self, search_space: List[ShardingOption]) -> None: + def load( + self, + search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, + ) -> None: self._reset() for sharding_option in search_space: fqn = sharding_option.fqn @@ -246,6 +266,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: # static strategy, ignore feedback and just provide next proposal self._proposal_index += 1 diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index d10209706..6fd3c8998 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -19,7 +19,12 @@ proposers_to_proposals_list, UniformProposer, ) -from torchrec.distributed.planner.types import Proposer, ShardingOption, Topology +from torchrec.distributed.planner.types import ( + Enumerator, + Proposer, + ShardingOption, + Topology, +) from torchrec.distributed.test_utils.test_model import TestSparseNN from torchrec.distributed.types import ModuleSharder, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -29,6 +34,7 @@ class MockProposer(Proposer): def load( self, search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, ) -> None: pass @@ -37,6 +43,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: pass diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 811615907..51bfbceb4 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -488,6 +488,13 @@ def enumerate( """ ... + @abc.abstractmethod + def populate_estimates(self, sharding_options: List[ShardingOption]) -> None: + """ + See class description. + """ + ... + class Proposer(abc.ABC): """ @@ -499,6 +506,7 @@ class Proposer(abc.ABC): def load( self, search_space: List[ShardingOption], + enumerator: Optional[Enumerator] = None, ) -> None: ... @@ -508,6 +516,7 @@ def feedback( partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, + storage_constraint: Optional[Topology] = None, ) -> None: ...