From 16b52eefb36d8675a785ad7311f053179db56349 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Mon, 30 Oct 2023 17:16:51 -0700
Subject: [PATCH] Add cache params to hash

Summary: We want to distinguish sharding options that differ only in cache
load factor, so the cache params need to be part of the hash.

Differential Revision: D50523588
---
 torchrec/distributed/planner/types.py |  1 +
 torchrec/distributed/types.py         | 11 +++++++++++
 2 files changed, 12 insertions(+)

diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 51bfbceb4..c31e2b93e 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -338,6 +338,7 @@ def __hash__(self) -> int:
                 self.sharding_type,
                 self.compute_kernel,
                 tuple(self.shards),
+                self.cache_params,
             )
         )
 
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index cba6b9b80..35465c0a4 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -481,6 +481,17 @@ class CacheParams:
     precision: Optional[DataType] = None
     prefetch_pipeline: Optional[bool] = None
 
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self.algorithm,
+                self.load_factor,
+                self.reserved_memory,
+                self.precision,
+                self.prefetch_pipeline,
+            )
+        )
+
 
 @dataclass
 class ParameterSharding:
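
For context, below is a minimal, runnable sketch of why the two hunks go
together. The classes are trimmed stand-ins, not the real torchrec
definitions (the ShardingOption constructor and the field types here are
hypothetical simplifications): a plain @dataclass with the default eq=True
has its __hash__ set to None, so CacheParams must define __hash__ explicitly
before ShardingOption.__hash__ can include self.cache_params in its tuple;
once it does, two options that differ only in load_factor no longer hash
alike.

from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheParams:
    # Trimmed stand-in for torchrec's CacheParams; precision is a str here
    # instead of DataType to keep the sketch self-contained.
    algorithm: Optional[str] = None
    load_factor: Optional[float] = None
    reserved_memory: Optional[float] = None
    precision: Optional[str] = None
    prefetch_pipeline: Optional[bool] = None

    # With eq=True (the dataclass default) and frozen=False, __hash__ would
    # otherwise be set to None, making instances unhashable; defining it
    # explicitly (as the patch does) restores hashability over the fields.
    def __hash__(self) -> int:
        return hash(
            (
                self.algorithm,
                self.load_factor,
                self.reserved_memory,
                self.precision,
                self.prefetch_pipeline,
            )
        )


class ShardingOption:
    # Hypothetical reduced constructor; the real class carries more fields.
    def __init__(
        self,
        fqn: str,
        sharding_type: str,
        compute_kernel: str,
        cache_params: Optional[CacheParams] = None,
    ) -> None:
        self.fqn = fqn
        self.sharding_type = sharding_type
        self.compute_kernel = compute_kernel
        self.cache_params = cache_params

    def __hash__(self) -> int:
        # Without self.cache_params in this tuple, the two options built
        # below would hash as identical candidates.
        return hash(
            (self.fqn, self.sharding_type, self.compute_kernel, self.cache_params)
        )


a = ShardingOption("t1", "row_wise", "fused_uvm_caching", CacheParams(load_factor=0.2))
b = ShardingOption("t1", "row_wise", "fused_uvm_caching", CacheParams(load_factor=0.5))
assert hash(a) != hash(b)  # distinguished by cache load factor (in practice)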