From 0c14b8d4c8739705b7410aea03828dce4d4d76a6 Mon Sep 17 00:00:00 2001 From: Eric Lippert Date: Tue, 4 Oct 2022 11:02:58 -0700 Subject: [PATCH] Start adding broadcasting to compiler Summary: This diff continues the work to add broadcasting semantics to the compiler and BMG. In this diff, I fix a bug in the type analyzer which was incorrectly detecting if two matrix types were broadcastable; it was only working if one of the types was a matrix size strictly bigger than the other, but broadcasting allows an n x 1 matrix to be broadcast with a 1 x m matrix regardless of the size of n and m. I've added a test case that demonstrates that the existing vectorization algorithm gets as far as generating an incorrect BMG graph; the graph needs to have broadcast nodes added to it. We'll fix that up in some upcoming diffs. Reviewed By: AishwaryaSivaraman Differential Revision: D40042673 fbshipit-source-id: 887c7df34763172c97cdf6ff03c9f14a22e754fb --- src/beanmachine/ppl/compiler/lattice_typer.py | 43 +++-- tests/ppl/compiler/broadcast_test.py | 150 ++++++++++++++++++ 2 files changed, 171 insertions(+), 22 deletions(-) create mode 100644 tests/ppl/compiler/broadcast_test.py diff --git a/src/beanmachine/ppl/compiler/lattice_typer.py b/src/beanmachine/ppl/compiler/lattice_typer.py index 856a5caee0..b4e753787d 100644 --- a/src/beanmachine/ppl/compiler/lattice_typer.py +++ b/src/beanmachine/ppl/compiler/lattice_typer.py @@ -156,6 +156,19 @@ } +def _broadcast_size(left: bt.BMGMatrixType, right: bt.BMGMatrixType): + def helper(x, y): + return x == 1 or y == 1 or x == y + + if not helper(left.rows, right.rows): + return None + if not helper(left.columns, right.columns): + return None + rows = left.rows if right.rows == 1 else right.rows + cols = left.columns if right.columns == 1 else right.columns + return (rows, cols) + + class LatticeTyper(TyperBase[bt.BMGLatticeType]): _dispatch: Dict[type, Callable] @@ -220,19 +233,6 @@ def _lattice_type_for_element_type( else: raise ValueError("unrecognized element type") - def __assert_can_be_broadcast_to( - self, small: bt.BMGMatrixType, big: bt.BMGMatrixType - ): - if small.rows == 1: - assert small.columns == 1 or small.columns == big.columns - else: - assert small.rows == big.rows - - if small.columns == 1: - assert small.rows == 1 or small.rows == big.rows - else: - assert small.columns == big.columns - def _type_binary_elementwise_op( self, node: bn.BinaryOperatorNode ) -> bt.BMGLatticeType: @@ -240,21 +240,20 @@ def _type_binary_elementwise_op( right_type = self[node.right] assert isinstance(left_type, bt.BMGMatrixType) assert isinstance(right_type, bt.BMGMatrixType) - r_count = right_type.columns * right_type.rows - l_count = left_type.columns * left_type.rows - if l_count < r_count: - self.__assert_can_be_broadcast_to(left_type, right_type) - else: - self.__assert_can_be_broadcast_to(right_type, left_type) + bsize = _broadcast_size(left_type, right_type) + if bsize is None: + return bt.Untypable + rows, cols = bsize + op_type = bt.supremum( self._lattice_type_for_element_type(left_type.element_type), self._lattice_type_for_element_type(right_type.element_type), ) if bt.supremum(op_type, bt.NegativeReal) == bt.NegativeReal: - return bt.NegativeRealMatrix(left_type.rows, left_type.columns) + return bt.NegativeRealMatrix(rows, cols) if bt.supremum(op_type, bt.PositiveReal) == bt.PositiveReal: - return bt.PositiveRealMatrix(left_type.rows, left_type.columns) - return bt.RealMatrix(left_type.rows, left_type.columns) + return bt.PositiveRealMatrix(rows, cols) + return bt.RealMatrix(rows, cols) _matrix_tpe_constructors = { bt.Real: lambda r, c: bt.RealMatrix(r, c), diff --git a/tests/ppl/compiler/broadcast_test.py b/tests/ppl/compiler/broadcast_test.py new file mode 100644 index 0000000000..9a25180714 --- /dev/null +++ b/tests/ppl/compiler/broadcast_test.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import beanmachine.ppl as bm +from beanmachine.ppl.inference import BMGInference +from torch import tensor +from torch.distributions import Normal + + +@bm.random_variable +def n(n): + return Normal(0, 1) + + +@bm.random_variable +def n12(): + return Normal(tensor([n(3), n(4)]), 1.0) + + +@bm.random_variable +def n21(): + return Normal(tensor([[n(1)], [n(2)]]), 1.0) + + +@bm.functional +def broadcast_add(): + return n12() + n21() + + +class BroadcastTest(unittest.TestCase): + def test_broadcast_add(self) -> None: + self.maxDiff = None + observations = {} + queries = [broadcast_add()] + + observed = BMGInference().to_dot(queries, observations, after_transform=False) + + # The model before the rewrite: + + expected = """ +digraph "graph" { + N00[label=0.0]; + N01[label=1.0]; + N02[label=Normal]; + N03[label=Sample]; + N04[label=Sample]; + N05[label=Tensor]; + N06[label=1.0]; + N07[label=Normal]; + N08[label=Sample]; + N09[label=Sample]; + N10[label=Sample]; + N11[label=Tensor]; + N12[label=Normal]; + N13[label=Sample]; + N14[label="+"]; + N15[label=Query]; + N00 -> N02; + N01 -> N02; + N02 -> N03; + N02 -> N04; + N02 -> N09; + N02 -> N10; + N03 -> N05; + N04 -> N05; + N05 -> N07; + N06 -> N07; + N06 -> N12; + N07 -> N08; + N08 -> N14; + N09 -> N11; + N10 -> N11; + N11 -> N12; + N12 -> N13; + N13 -> N14; + N14 -> N15; +} +""" + self.assertEqual(expected.strip(), observed.strip()) + + # After: + + observed = BMGInference().to_dot(queries, observations, after_transform=True) + expected = """ +digraph "graph" { + N00[label=0.0]; + N01[label=1.0]; + N02[label=Normal]; + N03[label=Sample]; + N04[label=Sample]; + N05[label=Normal]; + N06[label=Sample]; + N07[label=Normal]; + N08[label=Sample]; + N09[label=Sample]; + N10[label=Sample]; + N11[label=Normal]; + N12[label=Sample]; + N13[label=Normal]; + N14[label=Sample]; + N15[label=2]; + N16[label=1]; + N17[label=ToMatrix]; + N18[label=ToMatrix]; + N19[label=MatrixAdd]; + N20[label=Query]; + N00 -> N02; + N01 -> N02; + N01 -> N05; + N01 -> N07; + N01 -> N11; + N01 -> N13; + N02 -> N03; + N02 -> N04; + N02 -> N09; + N02 -> N10; + N03 -> N05; + N04 -> N07; + N05 -> N06; + N06 -> N17; + N07 -> N08; + N08 -> N17; + N09 -> N11; + N10 -> N13; + N11 -> N12; + N12 -> N18; + N13 -> N14; + N14 -> N18; + N15 -> N17; + N15 -> N18; + N16 -> N17; + N16 -> N18; + N17 -> N19; + N18 -> N19; + N19 -> N20; +} +""" + self.assertEqual(expected.strip(), observed.strip()) + + # BMG + + with self.assertRaises(ValueError): + g, _ = BMGInference().to_graph(queries, observations) + # observed = g.to_dot() + # expected = "" + # self.assertEqual(expected.strip(), observed.strip())