Skip to content

Commit

Permalink
silu batch rule
Browse files Browse the repository at this point in the history
  • Loading branch information
zou3519 committed Apr 7, 2022
1 parent cca6486 commit c17bf9a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions functorch/csrc/BatchRulesBinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
BINARY_POINTWISE(softshrink_backward);
BINARY_POINTWISE(tanh_backward);
BINARY_POINTWISE(threshold_backward);
BINARY_POINTWISE(silu_backward);

using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
Expand Down
12 changes: 12 additions & 0 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,18 @@ def test_copy_(self):
with self.assertRaisesRegex(RuntimeError, 'inplace'):
vmap(Tensor.copy_, in_dims=(None, 0))(x, y)

def test_silu_backward(self):
test = self._vmap_test
device = 'cpu'
getter = TensorFactory.randp1
B0 = 7
op = torch.ops.aten.silu_backward

# Single vmap: op(Tensor, Tensor)
test(op, (getter([B0, 3], device), getter([B0, 3], device)))
test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0))
test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None))

@parametrize('case', [
subtest(_make_case(torch.add), name='add'),
subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
Expand Down

0 comments on commit c17bf9a

Please sign in to comment.