From f5ce61446d38d4c2b54fab111fa79c3156207de3 Mon Sep 17 00:00:00 2001 From: Samantha Andow Date: Fri, 8 Apr 2022 09:32:16 -0400 Subject: [PATCH] Fix normal_ and bernoulli (#670) * normal_fix * fix binomial test --- functorch/csrc/BatchRulesBinaryOps.cpp | 12 ++++++++---- test/test_vmap.py | 7 +++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/csrc/BatchRulesBinaryOps.cpp index 34b1d8881..912d25e61 100644 --- a/functorch/csrc/BatchRulesBinaryOps.cpp +++ b/functorch/csrc/BatchRulesBinaryOps.cpp @@ -292,14 +292,18 @@ std::tuple<Tensor, optional<int64_t>> cdist_backward_batch_rule( return std::make_tuple(out, out_bdim); } +Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) { + return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous +} + TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { #define BINARY_RANDOM_POINTWISE(op) \ - m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op))); -#define BINARY_RANDOM_POINTWISE2(op, overload) \ - m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload))); + m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op))); + #define BINARY_RANDOM_POINTWISE2(op, overload) \ + m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload))); BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor); - BINARY_RANDOM_POINTWISE(binomial); + m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper)); } TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { diff --git a/test/test_vmap.py b/test/test_vmap.py index 071d89977..5cf1c862a 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3814,7 +3814,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i lambda t, _: t.random_(**kwargs), lambda t, _: t.random_(100, **kwargs), lambda t, _: t.random_(-5, 100, **kwargs), - # lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim + lambda t, _: t.normal_(**kwargs), lambda t, _: t.bernoulli_(**kwargs), lambda t, _: t.cauchy_(**kwargs), lambda t, _: t.exponential_(**kwargs), @@ -3851,7 +3851,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i self.assertEqual(vmap_result, expected) else: if batched_input != "none": - passed_expected = passed_expected[0] + passed_expected = passed_expected[0].clone() # bug in pytorch, normal_ on views doesn't work expected = op(passed_expected, always_batched) self._assert_all_slices_equal(vmap_result) for i in range(B0): @@ -3923,8 +3923,7 @@ def test_random_binary_out_of_place(self, device, use_generator, randomness, bat kwargs = {'generator': generator} if use_generator else {} ops = [ lambda t, o, _: torch.normal(t, o, **kwargs), - # TODO(samdow): fix binomial - # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs), + lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs), ] B0 = 4