diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py index 9ec4b52bc..c62bd797e 100644 --- a/torchrec/distributed/tests/test_embedding_sharding.py +++ b/torchrec/distributed/tests/test_embedding_sharding.py @@ -36,6 +36,11 @@ class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase): def test_get_avg_cache_load_factor_hbm(self) -> None: + from typing import * + + x: Optional[int | str] = None + + cache_load_factors = [random.random() for _ in range(5)] embedding_tables: List[ShardedEmbeddingTable] = [ ShardedEmbeddingTable(