From e072fd0991925adbca552d3197627e814648e338 Mon Sep 17 00:00:00 2001 From: samdow Date: Wed, 15 Jun 2022 13:41:58 -0400 Subject: [PATCH] embedding decomp --- functorch/_src/eager_transforms.py | 1 + functorch/csrc/DynamicLayer.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/functorch/_src/eager_transforms.py b/functorch/_src/eager_transforms.py index 0fa0559adf..e814dd8351 100644 --- a/functorch/_src/eager_transforms.py +++ b/functorch/_src/eager_transforms.py @@ -1340,5 +1340,6 @@ def _register_python_decomposition_vmap(decomp): _register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default) _register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default) _register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default, use_python=True) +_register_jit_decomposition(torch.ops.aten.embedding_dense_backward.default) _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) _register_python_decomposition_vmap(torch.ops.aten.addr.default) diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index dc30a70bd7..930c5e201a 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -503,6 +503,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { JVP_DECOMP(log_sigmoid_forward); JVP_DECOMP(native_layer_norm_backward); JVP_DECOMP(native_batch_norm_backward); + JVP_DECOMP(embedding_dense_backward); }