Skip to content

Commit

Permalink
Redefine FBGEMM targets with gpu_cpp_library [13/N]
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#252

- Redefine `permute_pooled_embeddings_*` targets using `gpu_cpp_library`

Differential Revision: D63053938
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 20, 2024
1 parent fa50993 commit 90a64df
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
14 changes: 4 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import List, Optional

import torch
from fbgemm_gpu.utils.loader import load_torch_module

try:
# pyre-ignore[21]
Expand All @@ -19,16 +20,9 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
try:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
)
except OSError:
# This is for forward compatibility (new torch.package + old backend)
# We should be able to remove it after this diff is picked up by all backend
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_cuda"
)
load_torch_module(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
)


class PermutePooledEmbeddings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include "include/fbgemm_gpu/permute_pooled_embedding_function.h"
#include "fbgemm_gpu/permute_pooled_embedding_function.h"

using Tensor = at::Tensor;

Expand Down

0 comments on commit 90a64df

Please sign in to comment.