

Allow passing feature names through ParameterConstraints and ShardingOption

Summary:
Allow users to pass in feature_names. Preferably, it should be the same as table.feature_names.

When no constraints are passed in, feature_names will be None.

Another option would be to always extract the feature names from BaseEmbeddingConfig, but that seems to go against convention.
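
For context, a minimal sketch of the intended usage (illustrative only, mirroring the test setup added in this commit; table and feature names are placeholders):

    # Mirror each table's feature_names into its ParameterConstraints so the
    # planner can carry them through to the resulting ShardingOption.
    from torchrec.distributed.planner import ParameterConstraints
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    tables = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=64,
            name=f"table_{i}",
            feature_names=[f"feature_{i}"],
        )
        for i in range(2)
    ]

    # Preferably feature_names matches table.feature_names; tables without a
    # constraint end up with ShardingOption.feature_names set to None.
    constraints = {
        table.name: ParameterConstraints(feature_names=table.feature_names)
        for table in tables
    }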

Differential Revision: D53454672
henrylhtsang authored and facebook-github-bot committed Feb 6, 2024
1 parent 0c2462f commit 2157f0e
Showing 3 changed files with 58 additions and 13 deletions.
6 changes: 6 additions & 0 deletions torchrec/distributed/planner/enumerators.py
@@ -125,6 +125,7 @@ def enumerate(
enforce_hbm,
stochastic_rounding,
bounds_check_mode,
feature_names,
) = _extract_constraints_for_param(self._constraints, name)

sharding_options_per_table: List[ShardingOption] = []
@@ -172,6 +173,7 @@ def enumerate(
bounds_check_mode=bounds_check_mode,
dependency=dependency,
is_pooled=is_pooled,
feature_names=feature_names,
)
)
if not sharding_options_per_table:
@@ -265,13 +267,15 @@ def _extract_constraints_for_param(
Optional[bool],
Optional[bool],
Optional[BoundsCheckMode],
Optional[List[str]],
]:
input_lengths = [POOLING_FACTOR]
col_wise_shard_dim = None
cache_params = None
enforce_hbm = None
stochastic_rounding = None
bounds_check_mode = None
feature_names = None

if constraints and constraints.get(name):
input_lengths = constraints[name].pooling_factors
@@ -280,6 +284,7 @@
enforce_hbm = constraints[name].enforce_hbm
stochastic_rounding = constraints[name].stochastic_rounding
bounds_check_mode = constraints[name].bounds_check_mode
feature_names = constraints[name].feature_names

return (
input_lengths,
@@ -288,6 +293,7 @@
enforce_hbm,
stochastic_rounding,
bounds_check_mode,
feature_names,
)


62 changes: 49 additions & 13 deletions torchrec/distributed/planner/tests/test_planners.py
@@ -6,15 +6,20 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import cast, List
from typing import cast, List, Optional

import torch
from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology
from torchrec.distributed.planner.types import (
PlannerError,
PlannerErrorType,
ShardingOption,
Topology,
)
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import (
@@ -174,42 +179,48 @@ def setUp(self) -> None:
self.topology = Topology(
world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device
)
self.tables = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=64,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(4)
]
self.constraints = {
"table_0": ParameterConstraints(
enforce_hbm=True,
cache_params=CacheParams(
algorithm=CacheAlgorithm.LFU,
),
feature_names=self.tables[0].feature_names,
),
"table_1": ParameterConstraints(
enforce_hbm=False,
stochastic_rounding=True,
feature_names=self.tables[1].feature_names,
),
"table_2": ParameterConstraints(
bounds_check_mode=BoundsCheckMode.FATAL,
feature_names=self.tables[2].feature_names,
),
"table_2": ParameterConstraints(bounds_check_mode=BoundsCheckMode.FATAL),
"table_3": ParameterConstraints(
cache_params=CacheParams(
algorithm=CacheAlgorithm.LFU,
load_factor=0.1,
reserved_memory=1.0,
precision=DataType.FP16,
),
feature_names=self.tables[3].feature_names,
),
}
self.planner = EmbeddingShardingPlanner(
topology=self.topology, constraints=self.constraints
)

def test_fused_paramters_from_constraints(self) -> None:
tables = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=64,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(4)
]
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta"))
sharding_plan = self.planner.plan(module=model, sharders=get_default_sharders())

expected_fused_params = {
@@ -253,3 +264,28 @@ def test_fused_paramters_from_constraints(self) -> None:
),
expected_fused_params[table],
)

def test_passing_info_through_constraints(self) -> None:
model = TestSparseNN(tables=self.tables, sparse_device=torch.device("meta"))
_ = self.planner.plan(module=model, sharders=get_default_sharders())

best_plan: Optional[List[ShardingOption]] = self.planner._best_plan
self.assertIsNotNone(best_plan)

for table, constraint, sharding_option in zip(
self.tables, self.constraints.values(), best_plan
):
self.assertEqual(table.name, sharding_option.name)

self.assertEqual(table.feature_names, sharding_option.feature_names)
self.assertEqual(table.feature_names, constraint.feature_names)

self.assertEqual(constraint.cache_params, sharding_option.cache_params)
self.assertEqual(constraint.enforce_hbm, sharding_option.enforce_hbm)
self.assertEqual(
constraint.stochastic_rounding, sharding_option.stochastic_rounding
)
self.assertEqual(
constraint.bounds_check_mode, sharding_option.bounds_check_mode
)
self.assertEqual(constraint.is_weighted, sharding_option.is_weighted)
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -288,6 +288,7 @@ def __init__(
bounds_check_mode: Optional[BoundsCheckMode] = None,
dependency: Optional[str] = None,
is_pooled: Optional[bool] = None,
feature_names: Optional[List[str]] = None,
) -> None:
self.name = name
self._tensor = tensor
@@ -307,6 +308,7 @@ def __init__(
self.dependency = dependency
self._is_pooled = is_pooled
self.is_weighted: Optional[bool] = None
self.feature_names: Optional[List[str]] = feature_names

@property
def tensor(self) -> torch.Tensor:
@@ -436,6 +438,7 @@ class ParameterConstraints:
enforce_hbm: Optional[bool] = None
stochastic_rounding: Optional[bool] = None
bounds_check_mode: Optional[BoundsCheckMode] = None
feature_names: Optional[List[str]] = None


class PlannerErrorType(Enum):
