Allow passing the prefetch_pipeline flag through ParameterConstraints (pytorch#1435)

Summary:
Pull Request resolved: pytorch#1435

Allow passing the prefetch_pipeline flag through ParameterConstraints.

If the flag is passed both via the sharder's fused params and via ParameterConstraints, the value in ParameterConstraints takes precedence.
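For context, a minimal sketch of how a caller might set the flag through ParameterConstraints. This is not part of this diff: the table name "sparse_table_0" and the cache values are hypothetical, and the import paths assume torchrec's planner API at the time of this commit.

    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import CacheAlgorithm, CacheParams

    # Per-table constraints consumed by the sharding planner.
    constraints = {
        "sparse_table_0": ParameterConstraints(
            cache_params=CacheParams(
                algorithm=CacheAlgorithm.LFU,
                load_factor=0.2,
                # Per this change: if the sharder's fused_params also carries
                # prefetch_pipeline, this value wins.
                prefetch_pipeline=True,
            ),
        ),
    }

The constraints dict would then be handed to the planner (e.g. EmbeddingShardingPlanner(topology=..., constraints=constraints)), which propagates cache_params into each table's ParameterSharding, where add_params_from_parameter_sharding (see the utils.py hunk below) folds it into fused_params.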

Reviewed By: YLGH

Differential Revision: D50095093

fbshipit-source-id: 84ff9ab0fc8d0b73ede8f618e3daa93eee09f513
henrylhtsang authored and facebook-github-bot committed Oct 10, 2023
1 parent fd5269d commit 3870c59
Showing 3 changed files with 11 additions and 1 deletion.
9 changes: 8 additions & 1 deletion torchrec/distributed/tests/test_utils.py
@@ -360,7 +360,11 @@ def setUp(self) -> None:
     compute_kernel="dense",
     ranks=[0, 1],
     sharding_spec=None,
-    cache_params=CacheParams(algorithm=CacheAlgorithm.LFU, reserved_memory=1.0),
+    cache_params=CacheParams(
+        algorithm=CacheAlgorithm.LFU,
+        reserved_memory=1.0,
+        prefetch_pipeline=False,
+    ),
     enforce_hbm=False,
     stochastic_rounding=True,
     bounds_check_mode=BoundsCheckMode.WARNING,
@@ -374,6 +378,7 @@ def test_add_params_from_parameter_sharding(self) -> None:
     expected_fused_params = {
         "cache_algorithm": CacheAlgorithm.LFU,
         "cache_reserved_memory": 1.0,
+        "prefetch_pipeline": False,
         "enforce_hbm": False,
         "stochastic_rounding": True,
         "bounds_check_mode": BoundsCheckMode.WARNING,
@@ -385,6 +390,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None:
         "learning_rate": 0.1,
         "cache_algorithm": CacheAlgorithm.LRU,
         "stochastic_rounding": False,
+        "prefetch_pipeline": True,
     }
     fused_params = add_params_from_parameter_sharding(
         fused_params, self.parameter_sharding
@@ -393,6 +399,7 @@ def test_add_params_from_parameter_sharding_override(self) -> None:
         "learning_rate": 0.1,
         "cache_algorithm": CacheAlgorithm.LFU,
         "cache_reserved_memory": 1.0,
+        "prefetch_pipeline": False,
         "enforce_hbm": False,
         "stochastic_rounding": True,
         "bounds_check_mode": BoundsCheckMode.WARNING,
1 change: 1 addition & 0 deletions torchrec/distributed/types.py
@@ -493,6 +493,7 @@ class CacheParams:
     load_factor: Optional[float] = None
     reserved_memory: Optional[float] = None
     precision: Optional[DataType] = None
+    prefetch_pipeline: Optional[bool] = None


 @dataclass
2 changes: 2 additions & 0 deletions torchrec/distributed/utils.py
@@ -384,6 +384,8 @@ def add_params_from_parameter_sharding(
         fused_params["cache_reserved_memory"] = cache_params.reserved_memory
     if cache_params.precision is not None:
         fused_params["cache_precision"] = cache_params.precision
+    if cache_params.prefetch_pipeline is not None:
+        fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline

     if parameter_sharding.enforce_hbm is not None:
         fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
