Make changes to proposer design (pytorch#1505)
Summary:

Allow the proposer to call the enumerator. Together with Enumerator.populate_estimates, this lets a proposer change its sharding options and re-estimate their perfs and storages.

Reviewed By: ge0405

Differential Revision: D50514266
henrylhtsang authored and facebook-github-bot committed Nov 13, 2023
1 parent 9db35bb commit 323699b
Showing 4 changed files with 51 additions and 8 deletions.
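As context for the diffs below, here is a minimal sketch (not part of this commit) of a proposer built against the new interface: it stores the enumerator handed over in load() and, on negative feedback, could mutate its sharding options and ask the enumerator to refresh their perf/storage estimates. The class name and bookkeeping fields are hypothetical.

from typing import List, Optional

from torchrec.distributed.planner.types import (
    Enumerator,
    Proposer,
    ShardingOption,
    Topology,
)


class ReestimatingProposer(Proposer):
    # Hypothetical proposer sketch: not in torchrec, only illustrates the
    # load/feedback signatures introduced by this commit.

    def load(
        self,
        search_space: List[ShardingOption],
        enumerator: Optional[Enumerator] = None,
    ) -> None:
        self._search_space = search_space
        self._enumerator = enumerator
        self._proposed = False

    def propose(self) -> Optional[List[ShardingOption]]:
        if self._proposed:
            return None
        self._proposed = True
        return self._search_space

    def feedback(
        self,
        partitionable: bool,
        plan: Optional[List[ShardingOption]] = None,
        perf_rating: Optional[float] = None,
        storage_constraint: Optional[Topology] = None,
    ) -> None:
        if not partitionable and self._enumerator is not None:
            # ...mutate self._search_space here, then refresh the
            # perf and storage estimates in place:
            self._enumerator.populate_estimates(self._search_space)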
12 changes: 9 additions & 3 deletions torchrec/distributed/planner/planners.py
@@ -229,7 +229,7 @@ def plan(
         ] = {}
 
         for proposer in self._proposers:
-            proposer.load(search_space=search_space)
+            proposer.load(search_space=search_space, enumerator=self._enumerator)
 
         for proposer in self._proposers:
             proposal = proposer.propose()
@@ -242,6 +242,7 @@ def plan(
                         partitionable=partitionable,
                         plan=plan,
                         perf_rating=perf_rating,
+                        storage_constraint=storage_constraint,
                     )
                     proposal = proposer.propose()
                     continue
@@ -260,7 +261,10 @@ def plan(
                         best_plan = copy.deepcopy(plan)
                     proposal_cache[proposal_key] = (True, plan, perf_rating)
                     proposer.feedback(
-                        partitionable=True, plan=plan, perf_rating=perf_rating
+                        partitionable=True,
+                        plan=plan,
+                        perf_rating=perf_rating,
+                        storage_constraint=storage_constraint,
                     )
                 except PlannerError as planner_error:
                     last_planner_error = planner_error
@@ -280,7 +284,9 @@ def plan(
                     if current_storage < lowest_storage:
                         lowest_storage = current_storage
                     proposal_cache[proposal_key] = (False, None, None)
-                    proposer.feedback(partitionable=False)
+                    proposer.feedback(
+                        partitionable=False, storage_constraint=storage_constraint
+                    )
 
                 # clear shard.rank for each sharding_option
                 reset_shard_rank(proposal)
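Condensed, the propose/feedback loop these hunks touch now looks roughly like the sketch below (simplified from plan(); the proposal cache, best-plan tracking, and error bookkeeping are elided):

for proposer in self._proposers:
    proposer.load(search_space=search_space, enumerator=self._enumerator)

for proposer in self._proposers:
    proposal = proposer.propose()
    while proposal:
        try:
            plan = self._partitioner.partition(
                proposal=proposal, storage_constraint=storage_constraint
            )
            perf_rating = self._perf_model.rate(plan=plan)
            proposer.feedback(
                partitionable=True,
                plan=plan,
                perf_rating=perf_rating,
                storage_constraint=storage_constraint,
            )
        except PlannerError:
            proposer.feedback(
                partitionable=False, storage_constraint=storage_constraint
            )
        reset_shard_rank(proposal)
        proposal = proposer.propose()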
29 changes: 25 additions & 4 deletions torchrec/distributed/planner/proposers.py
@@ -10,7 +10,13 @@
 from decimal import Decimal
 from typing import cast, Dict, List, Optional, Set, Tuple
 
-from torchrec.distributed.planner.types import Perf, Proposer, ShardingOption
+from torchrec.distributed.planner.types import (
+    Enumerator,
+    Perf,
+    Proposer,
+    ShardingOption,
+    Topology,
+)
 from torchrec.distributed.planner.utils import prod
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +49,11 @@ def __init__(self, use_depth: bool = True, threshold: Optional[int] = None) -> None:
         self._best_perf_rating: float = float("inf")
         self._num_inferior_perf: int = 0
 
-    def load(self, search_space: List[ShardingOption]) -> None:
+    def load(
+        self,
+        search_space: List[ShardingOption],
+        enumerator: Optional[Enumerator] = None,
+    ) -> None:
         self._reset()
         for sharding_option in search_space:
             fqn = sharding_option.fqn
@@ -78,6 +88,7 @@ def feedback(
         partitionable: bool,
         plan: Optional[List[ShardingOption]] = None,
         perf_rating: Optional[float] = None,
+        storage_constraint: Optional[Topology] = None,
     ) -> None:
         # When threshold is passed, observe the perf_rating trend. If the perf_rating
         # of the newly proposed plans have worse perf_rating, stop proposing.
@@ -126,7 +137,11 @@ def __init__(self, use_depth: bool = True) -> None:
         self._grouped_sharding_options: List[List[ShardingOption]] = []
         self._proposal_index: int = 0
 
-    def load(self, search_space: List[ShardingOption]) -> None:
+    def load(
+        self,
+        search_space: List[ShardingOption],
+        enumerator: Optional[Enumerator] = None,
+    ) -> None:
         self._reset()
         all_fqns = set()
         sharding_options_by_type_and_fqn: Dict[
@@ -175,6 +190,7 @@ def feedback(
         partitionable: bool,
         plan: Optional[List[ShardingOption]] = None,
         perf_rating: Optional[float] = None,
+        storage_constraint: Optional[Topology] = None,
     ) -> None:
         # static strategy, ignore feedback and just provide next proposal
         self._proposal_index += 1
@@ -187,7 +203,11 @@ def __init__(self, max_proposals: int = MAX_PROPOSALS) -> None:
         self._proposal_index: int = 0
         self._proposals: List[List[int]] = []
 
-    def load(self, search_space: List[ShardingOption]) -> None:
+    def load(
+        self,
+        search_space: List[ShardingOption],
+        enumerator: Optional[Enumerator] = None,
+    ) -> None:
         self._reset()
         for sharding_option in search_space:
             fqn = sharding_option.fqn
@@ -246,6 +266,7 @@ def feedback(
         partitionable: bool,
         plan: Optional[List[ShardingOption]] = None,
         perf_rating: Optional[float] = None,
+        storage_constraint: Optional[Topology] = None,
     ) -> None:
         # static strategy, ignore feedback and just provide next proposal
         self._proposal_index += 1
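The concrete proposers (GreedyProposer, UniformProposer, GridSearchProposer) accept the new arguments for interface compatibility but ignore them, as their comments note. Wiring a proposer and enumerator together might look like the following sketch, where topology, batch_size, model, and sharders are assumed to be set up elsewhere:

from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.proposers import GreedyProposer

# Enumerate the search space, then hand both the space and the
# enumerator itself to the proposer.
enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
search_space = enumerator.enumerate(module=model, sharders=sharders)

proposer = GreedyProposer()
proposer.load(search_space=search_space, enumerator=enumerator)
proposal = proposer.propose()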
9 changes: 8 additions & 1 deletion torchrec/distributed/planner/tests/test_proposers.py
@@ -19,7 +19,12 @@
     proposers_to_proposals_list,
     UniformProposer,
 )
-from torchrec.distributed.planner.types import Proposer, ShardingOption, Topology
+from torchrec.distributed.planner.types import (
+    Enumerator,
+    Proposer,
+    ShardingOption,
+    Topology,
+)
 from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.types import ModuleSharder, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -29,6 +34,7 @@ class MockProposer(Proposer):
     def load(
         self,
         search_space: List[ShardingOption],
+        enumerator: Optional[Enumerator] = None,
     ) -> None:
         pass
 
@@ -37,6 +43,7 @@ def feedback(
         partitionable: bool,
         plan: Optional[List[ShardingOption]] = None,
         perf_rating: Optional[float] = None,
+        storage_constraint: Optional[Topology] = None,
     ) -> None:
         pass
 
9 changes: 9 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -496,6 +496,13 @@ def enumerate(
         """
         ...
 
+    @abc.abstractmethod
+    def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
+        """
+        See class description.
+        """
+        ...
+
 
 class Proposer(abc.ABC):
     """
@@ -507,6 +514,7 @@ class Proposer(abc.ABC):
     def load(
         self,
         search_space: List[ShardingOption],
+        enumerator: Optional[Enumerator] = None,
     ) -> None:
         ...
 
@@ -516,6 +524,7 @@ def feedback(
         partitionable: bool,
         plan: Optional[List[ShardingOption]] = None,
         perf_rating: Optional[float] = None,
+        storage_constraint: Optional[Topology] = None,
     ) -> None:
         ...
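The new Enumerator.populate_estimates hook is what lets a proposer re-estimate mutated options. A concrete Enumerator would typically implement it by re-running its shard estimators, along these lines (a sketch; the _estimators and _sharder_map attributes are illustrative, not the actual field names):

    def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
        # Refresh perf and storage in place so they reflect any mutations
        # the proposer made to the sharding options.
        for estimator in self._estimators:
            estimator.estimate(sharding_options, sharder_map=self._sharder_map)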
