diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 8530dd008..650333d4c 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -1142,7 +1142,7 @@ def calculate_shard_storages( if compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}: ddr_storage = 0 - optimizer_class = getattr(tensor, "_optimizer_class", None) + optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0] hbm_specific_sizes: List[int] = _calculate_storage_specific_sizes( storage=hbm_storage,