diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 51bfbceb4..c31e2b93e 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -338,6 +338,7 @@ def __hash__(self) -> int: self.sharding_type, self.compute_kernel, tuple(self.shards), + self.cache_params, ) ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index cba6b9b80..35465c0a4 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -481,6 +481,17 @@ class CacheParams: precision: Optional[DataType] = None prefetch_pipeline: Optional[bool] = None + def __hash__(self) -> int: + return hash( + ( + self.algorithm, + self.load_factor, + self.reserved_memory, + self.precision, + self.prefetch_pipeline, + ) + ) + @dataclass class ParameterSharding: