Skip to content

Commit

Permalink
Add uvm to allowed compute kernel for zch (pytorch#1695)
Browse files Browse the repository at this point in the history
Summary:

Allow zch to use uvm.

Differential Revision: D53598269
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Feb 9, 2024
1 parent 47441cf commit 7c44f3d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchrec/distributed/mc_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def compute_kernels(
return [
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
EmbeddingComputeKernel.FUSED_UVM.value,
]

def sharding_types(self, compute_device_type: str) -> List[str]:
Expand Down
5 changes: 4 additions & 1 deletion torchrec/distributed/planner/tests/test_enumerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,10 @@ def test_filter_compute_kernels_mch_ebc(self) -> None:

self.assertEqual(
set(allowed_compute_kernels),
{EmbeddingComputeKernel.FUSED.value},
{
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM.value,
},
)

def test_filter_compute_kernels_mch_ebc_no_available(self) -> None:
Expand Down

0 comments on commit 7c44f3d

Please sign in to comment.