diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py index 6bfcf846f..632393939 100644 --- a/torchrec/distributed/tests/test_embedding_sharding.py +++ b/torchrec/distributed/tests/test_embedding_sharding.py @@ -219,3 +219,61 @@ def test_group_tables_per_rank_for_uvm_caching(self) -> None: ["table_3", "table_25", "table_26"], ] self.assertEqual(table_names_by_groups, expected_table_names_by_groups) + + def test_group_tables_per_rank_for_uvm_caching_with_prefetch(self) -> None: + self.compute_kernels = [ + EmbeddingComputeKernel.FUSED_UVM_CACHING for _ in range(self.num_tables) + ] + + # add prefetch + for fused_params in self.fused_params_groups: + fused_params["prefetch_pipeline"] = True + + embedding_tables = self.generate_embedding_tables() + + # since we don't have access to _group_tables_per_rank + tables_per_rank: List[List[ShardedEmbeddingTable]] = [embedding_tables] + + # taking only the list for the first rank + table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0] + table_names_by_groups = [ + table_group.table_names() for table_group in table_groups + ] + + expected_table_names_by_groups = [ + ["table_38"], + ["table_32"], + ["table_33"], + ["table_8"], + ["table_24", "table_41"], + ["table_34"], + ["table_36"], + ["table_2"], + ["table_22"], + ["table_23"], + ["table_30"], + ["table_46"], + ["table_27"], + ["table_19"], + ["table_40"], + ["table_16", "table_18"], + ["table_21"], + ["table_20", "table_31", "table_37", "table_48"], + ["table_42"], + ["table_14"], + ["table_5", "table_17", "table_28"], + ["table_0"], + ["table_6", "table_7", "table_10", "table_45"], + ["table_47"], + ["table_1", "table_35", "table_49"], + ["table_4", "table_44"], + ["table_13"], + ["table_29"], + ["table_11", "table_12"], + ["table_9", "table_15"], + ["table_43"], + ["table_39"], + ["table_3", "table_26"], + ["table_25"], + ] + self.assertEqual(table_names_by_groups, expected_table_names_by_groups)