From 2e0e188349acee292135a1d8122a56eada63f37d Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 9 Oct 2023 12:03:26 -0700 Subject: [PATCH] Add fused_params to fp-ebc (#1429) Summary: This is needed in the planner, since we rely on sharder to pass in global cache_load_factor. mc-ebc also has this problem. But it is less worrying. Reviewed By: YLGH Differential Revision: D49105069 --- torchrec/distributed/fp_embeddingbag.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py index 2c703b951..9f2020762 100644 --- a/torchrec/distributed/fp_embeddingbag.py +++ b/torchrec/distributed/fp_embeddingbag.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. from functools import partial -from typing import Dict, Iterator, List, Optional, Type, Union +from typing import Any, Dict, Iterator, List, Optional, Type, Union import torch from torch import nn @@ -190,6 +190,11 @@ def shard( device=device, ) + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecated after the planner gets cache_load_factor from ParameterConstraints + return self._ebc_sharder.fused_params + def shardable_parameters( self, module: FeatureProcessedEmbeddingBagCollection ) -> Dict[str, torch.nn.Parameter]: