diff --git a/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/csrc/BatchRulesBinaryOps.cpp
index 6aaf206f9..34b1d8881 100644
--- a/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -422,6 +422,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   BINARY_POINTWISE(softshrink_backward);
   BINARY_POINTWISE(tanh_backward);
   BINARY_POINTWISE(threshold_backward);
+  BINARY_POINTWISE(silu_backward);
 
   using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
   using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 2907a3991..071d89977 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -1420,6 +1420,18 @@ def test_copy_(self):
         with self.assertRaisesRegex(RuntimeError, 'inplace'):
             vmap(Tensor.copy_, in_dims=(None, 0))(x, y)
 
+    def test_silu_backward(self):
+        test = self._vmap_test
+        device = 'cpu'
+        getter = TensorFactory.randp1
+        B0 = 7
+        op = torch.ops.aten.silu_backward
+
+        # Single vmap: op(Tensor, Tensor)
+        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+        test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0))
+        test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None))
+
     @parametrize('case', [
         subtest(_make_case(torch.add), name='add'),
         subtest(_make_case(lambda x, y: x + y), name='add_dunder'),