
Commit

Add cache params to hash
Summary: Include the cache params in the hash so that sharding options that differ only in cache load factor can be distinguished.

Differential Revision: D50523588
hlhtsang authored and facebook-github-bot committed Oct 31, 2023
1 parent 9434f25 commit 16b52ee
Showing 2 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/planner/types.py
@@ -338,6 +338,7 @@ def __hash__(self) -> int:
                 self.sharding_type,
                 self.compute_kernel,
                 tuple(self.shards),
+                self.cache_params,
             )
         )

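The effect of the one-line change above, sketched with a hypothetical, simplified stand-in for ShardingOption (not the real planner class): once the cache params participate in the tuple that is hashed, two options that differ only in cache load factor no longer produce the same hash value.

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Option:  # hypothetical stand-in for ShardingOption
    sharding_type: str
    compute_kernel: str
    shards: Tuple[int, ...]
    cache_load_factor: Optional[float] = None  # stand-in for cache_params

    def __hash__(self) -> int:
        # Same pattern as the diff above: hash a tuple of the identifying
        # fields, now including the cache configuration.
        return hash(
            (self.sharding_type, self.compute_kernel, self.shards, self.cache_load_factor)
        )

a = Option("table_wise", "fused_uvm_caching", (0,), cache_load_factor=0.2)
b = Option("table_wise", "fused_uvm_caching", (0,), cache_load_factor=0.5)
print(hash(a) != hash(b))  # True: the options are now distinguishable by hash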
11 changes: 11 additions & 0 deletions torchrec/distributed/types.py
@@ -481,6 +481,17 @@ class CacheParams:
     precision: Optional[DataType] = None
     prefetch_pipeline: Optional[bool] = None
 
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self.algorithm,
+                self.load_factor,
+                self.reserved_memory,
+                self.precision,
+                self.prefetch_pipeline,
+            )
+        )
+
 
 @dataclass
 class ParameterSharding:
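Because CacheParams appears to be a non-frozen dataclass with generated equality (as the field-with-default style suggests), instances are unhashable by default: the dataclass machinery sets __hash__ to None when eq is true and frozen is false, so the hash has to be defined explicitly as above. A minimal, self-contained sketch of the resulting behavior, using a hypothetical Cache class that mirrors the fields in the diff rather than importing torchrec:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Cache:  # hypothetical mirror of CacheParams' hashing pattern
    algorithm: Optional[str] = None
    load_factor: Optional[float] = None
    reserved_memory: Optional[float] = None
    precision: Optional[str] = None
    prefetch_pipeline: Optional[bool] = None

    def __hash__(self) -> int:
        # An explicit __hash__ in the class body is kept by @dataclass,
        # making otherwise-unhashable mutable instances hashable.
        return hash(
            (
                self.algorithm,
                self.load_factor,
                self.reserved_memory,
                self.precision,
                self.prefetch_pipeline,
            )
        )

a = Cache(load_factor=0.2)
b = Cache(load_factor=0.5)
print(hash(a) != hash(b))  # True: different cache load factors hash differently
print(len({a, b}))         # 2: both survive deduplication in a set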
