diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py
index 2c703b951..9f2020762 100644
--- a/torchrec/distributed/fp_embeddingbag.py
+++ b/torchrec/distributed/fp_embeddingbag.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Type, Union
+from typing import Any, Dict, Iterator, List, Optional, Type, Union
 
 import torch
 from torch import nn
@@ -190,6 +190,11 @@ def shard(
             device=device,
         )
 
+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecated after the planner gets cache_load_factor from ParameterConstraints
+        return self._ebc_sharder.fused_params
+
     def shardable_parameters(
         self, module: FeatureProcessedEmbeddingBagCollection
     ) -> Dict[str, torch.nn.Parameter]:
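
The new property forwards the inner EmbeddingBagCollectionSharder's fused params so callers (notably the planner, until it reads cache_load_factor from ParameterConstraints) can see them on the wrapping sharder. A minimal usage sketch follows; it assumes the feature-processed sharder is constructed with an ebc_sharder argument and that EmbeddingBagCollectionSharder accepts fused_params, neither of which is shown in this diff.

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)

# The inner EBC sharder carries the fused params (cache_load_factor is the
# value the planner cares about here).
ebc_sharder = EmbeddingBagCollectionSharder(fused_params={"cache_load_factor": 0.2})
fp_sharder = FeatureProcessedEmbeddingBagCollectionSharder(ebc_sharder=ebc_sharder)

# With the new property, the fused params are visible on the wrapping sharder
# instead of only through the private _ebc_sharder attribute.
assert fp_sharder.fused_params == {"cache_load_factor": 0.2}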