Replace use of broadcastable with shape in aesara.tensor.shape
brandonwillard committed Nov 4, 2022
1 parent 74bd45f · commit 6c0fc17
Showing 2 changed files with 45 additions and 31 deletions.
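Background for the diff below: in Aesara's static-shape types, a dimension is broadcastable exactly when its static length is 1, and None marks an unknown length, so every `broadcastable` check can be restated in terms of `type.shape`. A minimal sketch of that correspondence (not part of the commit), using the long-standing `at.row` constructor:

    import aesara.tensor as at

    x = at.row("x")  # a matrix whose first dimension is fixed at length 1
    assert x.type.shape == (1, None)
    assert x.broadcastable == (True, False)
    # The translation used throughout this commit:
    assert tuple(s == 1 for s in x.type.shape) == x.broadcastable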
aesara/tensor/rewriting/shape.py (30 additions, 15 deletions)
@@ -364,8 +364,8 @@ def set_shape(self, r, s, override=False):
             else:
                 shape_vars.append(self.unpack(s[i], r))
         assert all(
-            not hasattr(r.type, "broadcastable")
-            or not r.type.broadcastable[i]
+            not hasattr(r.type, "shape")
+            or r.type.shape[i] != 1
             or self.lscalar_one.equals(shape_vars[i])
             or self.lscalar_one.equals(extract_constant(shape_vars[i]))
             for i in range(r.type.ndim)
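The same assertion pattern recurs in `update_shape` and `set_shape_i` below. A standalone paraphrase (names hypothetical) of the invariant it guards: any dimension the type pins to length 1 must be tracked as the constant 1, while unknown or non-1 dimensions impose no constraint here.

    def shape_entry_consistent(static_dim, tracked_is_constant_one):
        # static_dim: an entry of r.type.shape (None, 1, or another int).
        # tracked_is_constant_one: whether the ShapeFeature's entry for this
        # dimension is (or simplifies to) the constant 1.
        return static_dim != 1 or tracked_is_constant_one

    assert shape_entry_consistent(None, False)   # unknown length: unconstrained
    assert shape_entry_consistent(1, True)       # length-1 dim tracked as 1
    assert not shape_entry_consistent(1, False)  # would violate the invariant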
@@ -447,9 +447,9 @@ def update_shape(self, r, other_r):
                 merged_shape.append(other_shape[i])
         assert all(
             (
-                not hasattr(r.type, "broadcastable")
-                or not r.type.broadcastable[i]
-                and not other_r.type.broadcastable[i]
+                not hasattr(r.type, "shape")
+                or r.type.shape[i] != 1
+                and other_r.type.shape[i] != 1
             )
             or self.lscalar_one.equals(merged_shape[i])
             or self.lscalar_one.equals(
@@ -474,8 +474,8 @@ def set_shape_i(self, r, i, s_i):
             else:
                 new_shape.append(s_j)
         assert all(
-            not hasattr(r.type, "broadcastable")
-            or not r.type.broadcastable[idx]
+            not hasattr(r.type, "shape")
+            or r.type.shape[idx] != 1
             or self.lscalar_one.equals(new_shape[idx])
             or self.lscalar_one.equals(extract_constant(new_shape[idx]))
             for idx in range(r.type.ndim)
@@ -781,7 +781,11 @@ def f(fgraph, node):
            # We should try to figure out why we lost the information about this
            # constant value... but in the meantime, better not apply this
            # rewrite.
-            if rval.broadcastable == node.outputs[0].broadcastable:
+            if rval.type.ndim == node.outputs[0].type.ndim and all(
+                s1 == s2
+                for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+                if s1 == 1 or s2 == 1
+            ):
                 return [rval]
             else:
                 return False
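The replacement condition is weaker than tuple equality on `broadcastable`: only dimensions where either side is statically 1 have to agree, because those are the only dimensions whose broadcasting behaviour the rewrite could change. A standalone sketch of the predicate (function name hypothetical):

    def compatible_broadcast_pattern(shape1, shape2):
        # Shapes are tuples of None (unknown) or ints; dims where neither
        # side is 1 are ignored, since broadcasting is unaffected there.
        return len(shape1) == len(shape2) and all(
            s1 == s2 for s1, s2 in zip(shape1, shape2) if s1 == 1 or s2 == 1
        )

    assert compatible_broadcast_pattern((None, 1), (None, 1))
    assert compatible_broadcast_pattern((None, 5), (None, 7))  # no 1s: ignored
    assert not compatible_broadcast_pattern((None, 1), (None, None))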
@@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node):
     if (
         inp.type.ndim == 1
         and output.type.ndim == 1
-        and inp.type.broadcastable == output.type.broadcastable
+        and all(
+            s1 == s2
+            for s1, s2 in zip(inp.type.shape, output.type.shape)
+            if s1 == 1 or s2 == 1
+        )
     ):
         return [inp]
 
@@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node):
             shape_match[dim] = True
             continue
 
-        # Match 1 if input.broadcastable[dim] is True
+        # Match 1 if input.type.shape[dim] == 1
         cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
         if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
             shape_match[dim] = True
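For context on the two hunks above: `local_useless_reshape` removes a Reshape whose output provably has the same shape as its input. A hedged usage example of the 1-d case the first hunk rewrites:

    import aesara.tensor as at

    x = at.vector("x")
    y = x.reshape((x.shape[0],))  # reshape a vector to its own length
    # After rewriting, compiled graphs can use x directly wherever y appears.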
@@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node):
         if index != output.type.ndim:
             inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
             copy_stack_trace(output, inner)
-            new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
+            new_node = [
+                DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
+                    inner
+                )
+            ]
             copy_stack_trace(output, new_node)
             return new_node
 
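`DimShuffle`'s first argument is the input's broadcastable pattern, which this hunk now derives from the static shape instead of reading `inner.type.broadcastable`. A small sketch of that derivation, with an assumed static shape:

    inner_shape = (None, 1, 3)  # hypothetical inner.type.shape
    input_broadcastable = tuple(s == 1 for s in inner_shape)
    assert input_broadcastable == (False, True, False)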
@@ -1096,10 +1108,9 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
 
     new_order = node.inputs[0].owner.op.new_order
     inp = node.inputs[0].owner.inputs[0]
-    broadcastables = node.inputs[0].broadcastable
     new_order_of_nonbroadcast = []
-    for i, bd in zip(new_order, broadcastables):
-        if not bd:
+    for i, s in zip(new_order, node.inputs[0].type.shape):
+        if s != 1:
             new_order_of_nonbroadcast.append(i)
     no_change_in_order = all(
         new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
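`local_useless_dimshuffle_in_reshape` only needs the positions of the non-length-1 axes: inserted "x" axes and other length-1 axes cannot change the data layout, so the dimshuffle is useless when the remaining positions are already in order. A standalone sketch with hypothetical values:

    new_order = (0, "x", 1)      # hypothetical DimShuffle order
    out_shape = (None, 1, None)  # hypothetical static shape of its output
    non_broadcast = [i for i, s in zip(new_order, out_shape) if s != 1]
    assert non_broadcast == [0, 1]  # sorted, so the dimshuffle can be dropped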
@@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node):
     """
     if isinstance(node.op, Unbroadcast):
         x = node.inputs[0]
-        if x.broadcastable == node.outputs[0].broadcastable:
+        if x.type.ndim == node.outputs[0].type.ndim and all(
+            s1 == s2
+            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape)
+            if s1 == 1 or s2 == 1
+        ):
             # No broadcastable flag was modified
             # No need to copy over stack trace,
             # because x should already have a stack trace.
tests/tensor/test_shape.py (15 additions, 16 deletions)
@@ -55,13 +55,13 @@
 
 def test_shape_basic():
     s = shape([])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)
 
     s = shape([10])
-    assert s.type.broadcastable == (True,)
+    assert s.type.shape == (1,)
 
     s = shape(lscalar())
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (0,)
 
     class MyType(Type):
         def filter(self, *args, **kwargs):
@@ -71,7 +71,7 @@ def __eq__(self, other):
             return isinstance(other, MyType) and other.thingy == self.thingy
 
     s = shape(Variable(MyType(), None))
-    assert s.type.broadcastable == (False,)
+    assert s.type.shape == (None,)
 
     s = shape(np.array(1))
     assert np.array_equal(eval_outputs([s]), [])
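The expected values in this test follow from `shape(x)` returning a 1-d integer vector of length `x.ndim`: its own static shape is `(x.ndim,)` when the input's rank is known and `(None,)` otherwise. A hedged restatement of the scalar case, which often surprises readers:

    from aesara.tensor import lscalar
    from aesara.tensor.shape import shape

    s = shape(lscalar())
    # A 0-d tensor has an empty shape vector, so the result's static
    # shape is (0,): a vector statically known to have zero elements.
    assert s.type.shape == (0,)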
@@ -119,15 +119,14 @@ def test_basics(self):
         b = dmatrix()
         d = dmatrix()
 
-        # basic to 1 dim(without list)
-        c = reshape(b, as_tensor_variable(6), ndim=1)
-        f = self.function([b], c)
-
         b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]])
         c_val1 = np.asarray([0, 1, 2, 3, 4, 5])
         b_val2 = b_val1.T
         c_val2 = np.asarray([0, 3, 1, 4, 2, 5])
 
+        # basic to 1 dim(without list)
+        c = reshape(b, as_tensor_variable(6), ndim=1)
+        f = self.function([b], c)
         f_out1 = f(b_val1)
         f_out2 = f(b_val2)
         assert np.array_equal(f_out1, c_val1), (f_out1, c_val1)
@@ -191,10 +190,10 @@ def just_vals(v):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1], [2]], [[3], [4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )
 
         # test broadcast flag for constant value of 1 if it cannot be
@@ -205,10 +204,10 @@ def just_vals(v):
             f(np.asarray([[0, 1, 2], [3, 4, 5]])),
             np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]),
         )
-        assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == (
-            False,
-            False,
-            True,
+        assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == (
+            None,
+            None,
+            1,
         )
 
     def test_m1(self):
