Skip to content

Commit

Permalink
Add uvm caching with prefetch test for group_tables
Browse files Browse the repository at this point in the history
Summary: Adding a test after enable embedding dim bucketizer diff.

Differential Revision: D50483066

fbshipit-source-id: 6d43b0c5b7e6a59cb9b4377f38d6a0c2be1d4950
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Oct 20, 2023
1 parent e6687d5 commit 9965b12
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,61 @@ def test_group_tables_per_rank_for_uvm_caching(self) -> None:
["table_3", "table_25", "table_26"],
]
self.assertEqual(table_names_by_groups, expected_table_names_by_groups)

def test_group_tables_per_rank_for_uvm_caching_with_prefetch(self) -> None:
self.compute_kernels = [
EmbeddingComputeKernel.FUSED_UVM_CACHING for _ in range(self.num_tables)
]

# add prefetch
for fused_params in self.fused_params_groups:
fused_params["prefetch_pipeline"] = True

embedding_tables = self.generate_embedding_tables()

# since we don't have access to _group_tables_per_rank
tables_per_rank: List[List[ShardedEmbeddingTable]] = [embedding_tables]

# taking only the list for the first rank
table_groups: List[GroupedEmbeddingConfig] = group_tables(tables_per_rank)[0]
table_names_by_groups = [
table_group.table_names() for table_group in table_groups
]

expected_table_names_by_groups = [
["table_38"],
["table_32"],
["table_33"],
["table_8"],
["table_24", "table_41"],
["table_34"],
["table_36"],
["table_2"],
["table_22"],
["table_23"],
["table_30"],
["table_46"],
["table_27"],
["table_19"],
["table_40"],
["table_16", "table_18"],
["table_21"],
["table_20", "table_31", "table_37", "table_48"],
["table_42"],
["table_14"],
["table_5", "table_17", "table_28"],
["table_0"],
["table_6", "table_7", "table_10", "table_45"],
["table_47"],
["table_1", "table_35", "table_49"],
["table_4", "table_44"],
["table_13"],
["table_29"],
["table_11", "table_12"],
["table_9", "table_15"],
["table_43"],
["table_39"],
["table_3", "table_26"],
["table_25"],
]
self.assertEqual(table_names_by_groups, expected_table_names_by_groups)

0 comments on commit 9965b12

Please sign in to comment.