Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Start adding broadcasting to compiler
Browse files Browse the repository at this point in the history
Summary:
This diff continues the work to add broadcasting semantics to the compiler and BMG. In this diff, I fix a bug in the type analyzer which was incorrectly detecting if two matrix types were broadcastable; it was only working if one of the types was a matrix size strictly bigger than the other, but broadcasting allows an n x 1 matrix to be broadcast with a 1 x m matrix regardless of the size of n and m.

I've added a test case that demonstrates that the existing vectorization algorithm gets as far as generating an incorrect BMG graph; the graph needs to have broadcast nodes added to it.  We'll fix that up in some upcoming diffs.

Reviewed By: AishwaryaSivaraman

Differential Revision: D40042673

fbshipit-source-id: 887c7df34763172c97cdf6ff03c9f14a22e754fb
  • Loading branch information
ericlippert authored and facebook-github-bot committed Oct 4, 2022
1 parent 18f7a02 commit 0c14b8d
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/beanmachine/ppl/compiler/lattice_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@
}


def _broadcast_size(left: bt.BMGMatrixType, right: bt.BMGMatrixType):
def helper(x, y):
return x == 1 or y == 1 or x == y

if not helper(left.rows, right.rows):
return None
if not helper(left.columns, right.columns):
return None
rows = left.rows if right.rows == 1 else right.rows
cols = left.columns if right.columns == 1 else right.columns
return (rows, cols)


class LatticeTyper(TyperBase[bt.BMGLatticeType]):

_dispatch: Dict[type, Callable]
Expand Down Expand Up @@ -220,41 +233,27 @@ def _lattice_type_for_element_type(
else:
raise ValueError("unrecognized element type")

def __assert_can_be_broadcast_to(
self, small: bt.BMGMatrixType, big: bt.BMGMatrixType
):
if small.rows == 1:
assert small.columns == 1 or small.columns == big.columns
else:
assert small.rows == big.rows

if small.columns == 1:
assert small.rows == 1 or small.rows == big.rows
else:
assert small.columns == big.columns

def _type_binary_elementwise_op(
self, node: bn.BinaryOperatorNode
) -> bt.BMGLatticeType:
left_type = self[node.left]
right_type = self[node.right]
assert isinstance(left_type, bt.BMGMatrixType)
assert isinstance(right_type, bt.BMGMatrixType)
r_count = right_type.columns * right_type.rows
l_count = left_type.columns * left_type.rows
if l_count < r_count:
self.__assert_can_be_broadcast_to(left_type, right_type)
else:
self.__assert_can_be_broadcast_to(right_type, left_type)
bsize = _broadcast_size(left_type, right_type)
if bsize is None:
return bt.Untypable
rows, cols = bsize

op_type = bt.supremum(
self._lattice_type_for_element_type(left_type.element_type),
self._lattice_type_for_element_type(right_type.element_type),
)
if bt.supremum(op_type, bt.NegativeReal) == bt.NegativeReal:
return bt.NegativeRealMatrix(left_type.rows, left_type.columns)
return bt.NegativeRealMatrix(rows, cols)
if bt.supremum(op_type, bt.PositiveReal) == bt.PositiveReal:
return bt.PositiveRealMatrix(left_type.rows, left_type.columns)
return bt.RealMatrix(left_type.rows, left_type.columns)
return bt.PositiveRealMatrix(rows, cols)
return bt.RealMatrix(rows, cols)

_matrix_tpe_constructors = {
bt.Real: lambda r, c: bt.RealMatrix(r, c),
Expand Down
150 changes: 150 additions & 0 deletions tests/ppl/compiler/broadcast_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference
from torch import tensor
from torch.distributions import Normal


@bm.random_variable
def n(n):
return Normal(0, 1)


@bm.random_variable
def n12():
return Normal(tensor([n(3), n(4)]), 1.0)


@bm.random_variable
def n21():
return Normal(tensor([[n(1)], [n(2)]]), 1.0)


@bm.functional
def broadcast_add():
return n12() + n21()


class BroadcastTest(unittest.TestCase):
def test_broadcast_add(self) -> None:
self.maxDiff = None
observations = {}
queries = [broadcast_add()]

observed = BMGInference().to_dot(queries, observations, after_transform=False)

# The model before the rewrite:

expected = """
digraph "graph" {
N00[label=0.0];
N01[label=1.0];
N02[label=Normal];
N03[label=Sample];
N04[label=Sample];
N05[label=Tensor];
N06[label=1.0];
N07[label=Normal];
N08[label=Sample];
N09[label=Sample];
N10[label=Sample];
N11[label=Tensor];
N12[label=Normal];
N13[label=Sample];
N14[label="+"];
N15[label=Query];
N00 -> N02;
N01 -> N02;
N02 -> N03;
N02 -> N04;
N02 -> N09;
N02 -> N10;
N03 -> N05;
N04 -> N05;
N05 -> N07;
N06 -> N07;
N06 -> N12;
N07 -> N08;
N08 -> N14;
N09 -> N11;
N10 -> N11;
N11 -> N12;
N12 -> N13;
N13 -> N14;
N14 -> N15;
}
"""
self.assertEqual(expected.strip(), observed.strip())

# After:

observed = BMGInference().to_dot(queries, observations, after_transform=True)
expected = """
digraph "graph" {
N00[label=0.0];
N01[label=1.0];
N02[label=Normal];
N03[label=Sample];
N04[label=Sample];
N05[label=Normal];
N06[label=Sample];
N07[label=Normal];
N08[label=Sample];
N09[label=Sample];
N10[label=Sample];
N11[label=Normal];
N12[label=Sample];
N13[label=Normal];
N14[label=Sample];
N15[label=2];
N16[label=1];
N17[label=ToMatrix];
N18[label=ToMatrix];
N19[label=MatrixAdd];
N20[label=Query];
N00 -> N02;
N01 -> N02;
N01 -> N05;
N01 -> N07;
N01 -> N11;
N01 -> N13;
N02 -> N03;
N02 -> N04;
N02 -> N09;
N02 -> N10;
N03 -> N05;
N04 -> N07;
N05 -> N06;
N06 -> N17;
N07 -> N08;
N08 -> N17;
N09 -> N11;
N10 -> N13;
N11 -> N12;
N12 -> N18;
N13 -> N14;
N14 -> N18;
N15 -> N17;
N15 -> N18;
N16 -> N17;
N16 -> N18;
N17 -> N19;
N18 -> N19;
N19 -> N20;
}
"""
self.assertEqual(expected.strip(), observed.strip())

# BMG

with self.assertRaises(ValueError):
g, _ = BMGInference().to_graph(queries, observations)
# observed = g.to_dot()
# expected = ""
# self.assertEqual(expected.strip(), observed.strip())

0 comments on commit 0c14b8d

Please sign in to comment.