Fix normal_ and bernoulli (#670)
* normal_fix

* fix binomial test
Samantha Andow authored Apr 8, 2022
1 parent c17bf9a commit f5ce614
Showing 2 changed files with 11 additions and 8 deletions.
functorch/csrc/BatchRulesBinaryOps.cpp (12 changes: 8 additions & 4 deletions)
@@ -292,14 +292,18 @@ std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
   return std::make_tuple(out, out_bdim);
 }
 
+Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
+  return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
+}
+
 TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
 #define BINARY_RANDOM_POINTWISE(op) \
-  m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
-#define BINARY_RANDOM_POINTWISE2(op, overload) \
-  m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
+  m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
+#define BINARY_RANDOM_POINTWISE2(op, overload) \
+  m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
 
   BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
-  BINARY_RANDOM_POINTWISE(binomial);
+  m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
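
The new binomial_wrapper forces prob contiguous before calling at::binomial because, per the in-line comment, at::binomial mishandled non-contiguous probability tensors at the time. As a rough user-level sketch of the same workaround (not part of the commit; shapes and values are illustrative):

import torch

count = torch.full((3, 3), 10.0)
prob = torch.rand(3, 3).t()   # a transposed view is a typical non-contiguous tensor
assert not prob.is_contiguous()

# Mirror binomial_wrapper: make prob contiguous before sampling, sidestepping
# the PyTorch bug with non-contiguous prob noted in the comment above.
samples = torch.binomial(count, prob.contiguous())
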
test/test_vmap.py (7 changes: 3 additions & 4 deletions)
@@ -3814,7 +3814,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
             lambda t, _: t.random_(**kwargs),
             lambda t, _: t.random_(100, **kwargs),
             lambda t, _: t.random_(-5, 100, **kwargs),
-            # lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim
+            lambda t, _: t.normal_(**kwargs),
             lambda t, _: t.bernoulli_(**kwargs),
             lambda t, _: t.cauchy_(**kwargs),
             lambda t, _: t.exponential_(**kwargs),
@@ -3851,7 +3851,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
             self.assertEqual(vmap_result, expected)
         else:
             if batched_input != "none":
-                passed_expected = passed_expected[0]
+                passed_expected = passed_expected[0].clone()  # bug in pytorch, normal_ on views doesn't work
             expected = op(passed_expected, always_batched)
             self._assert_all_slices_equal(vmap_result)
             for i in range(B0):
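
The added .clone() works around a separate upstream issue: calling normal_ in place on a view (here, passed_expected[0]) did not behave correctly, so the test mutates a standalone copy instead. A minimal sketch of the pattern, with hypothetical shapes:

import torch

base = torch.zeros(4, 3)
view = base[0]         # indexing yields a view into base's storage
fresh = view.clone()   # clone() produces a standalone tensor, not a view
fresh.normal_()        # in-place draws from N(0, 1) now target the copy
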
@@ -3923,8 +3923,7 @@ def test_random_binary_out_of_place(self, device, use_generator, randomness, bat
         kwargs = {'generator': generator} if use_generator else {}
         ops = [
             lambda t, o, _: torch.normal(t, o, **kwargs),
-            # TODO(samdow): fix binomial
-            # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
+            lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
         ]
 
         B0 = 4
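
With the batch rule now routed through binomial_wrapper, the binomial case can be re-enabled in this test. A hedged sketch of the behavior the test exercises, assuming functorch's vmap with randomness='different' (names and shapes here are illustrative, not from the test):

import torch
from functorch import vmap

B0 = 4
count = torch.full((B0, 3), 10.0)
prob = torch.rand(B0, 3)

# randomness='different' lets each batch element draw its own samples;
# before this fix, vmapping torch.binomial over batched inputs failed.
samples = vmap(torch.binomial, randomness='different')(count, prob)
print(samples.shape)  # torch.Size([4, 3])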
