diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index b556430c4..fc2a514b5 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -112,7 +112,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: cache_load_factor = 0.2 local_rows_sum: int = sum(table.local_rows for table in config.embedding_tables) - ssd_tbe_params["cache_sets"] = int(cache_load_factor * local_rows_sum / ASSOC) + ssd_tbe_params["cache_sets"] = max( + int(cache_load_factor * local_rows_sum / ASSOC), 1 + ) # populate init min and max if (