From 6c0fc177b0d9d26e1355376aac2bcf5021f7c72a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 1 Nov 2022 19:22:28 -0500 Subject: [PATCH] Replace use of broadcastable with shape in aesara.tensor.shape --- aesara/tensor/rewriting/shape.py | 45 +++++++++++++++++++++----------- tests/tensor/test_shape.py | 31 +++++++++++----------- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index 5844367344..87d77b1322 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -364,8 +364,8 @@ def set_shape(self, r, s, override=False): else: shape_vars.append(self.unpack(s[i], r)) assert all( - not hasattr(r.type, "broadcastable") - or not r.type.broadcastable[i] + not hasattr(r.type, "shape") + or r.type.shape[i] != 1 or self.lscalar_one.equals(shape_vars[i]) or self.lscalar_one.equals(extract_constant(shape_vars[i])) for i in range(r.type.ndim) @@ -447,9 +447,9 @@ def update_shape(self, r, other_r): merged_shape.append(other_shape[i]) assert all( ( - not hasattr(r.type, "broadcastable") - or not r.type.broadcastable[i] - and not other_r.type.broadcastable[i] + not hasattr(r.type, "shape") + or r.type.shape[i] != 1 + and other_r.type.shape[i] != 1 ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( @@ -474,8 +474,8 @@ def set_shape_i(self, r, i, s_i): else: new_shape.append(s_j) assert all( - not hasattr(r.type, "broadcastable") - or not r.type.broadcastable[idx] + not hasattr(r.type, "shape") + or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(extract_constant(new_shape[idx])) for idx in range(r.type.ndim) @@ -781,7 +781,11 @@ def f(fgraph, node): # We should try to figure out why we lost the information about this # constant value... but in the meantime, better not apply this # rewrite. 
- if rval.broadcastable == node.outputs[0].broadcastable: + if rval.type.ndim == node.outputs[0].type.ndim and all( + s1 == s2 + for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) + if s1 == 1 or s2 == 1 + ): return [rval] else: return False @@ -816,7 +820,11 @@ def local_useless_reshape(fgraph, node): if ( inp.type.ndim == 1 and output.type.ndim == 1 - and inp.type.broadcastable == output.type.broadcastable + and all( + s1 == s2 + for s1, s2 in zip(inp.type.shape, output.type.shape) + if s1 == 1 or s2 == 1 + ) ): return [inp] @@ -862,7 +870,7 @@ def local_useless_reshape(fgraph, node): shape_match[dim] = True continue - # Match 1 if input.broadcastable[dim] is True + # Match 1 if input.type.shape[dim] == 1 cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) if inp.type.shape[dim] == 1 and cst_outshp_i == 1: shape_match[dim] = True @@ -931,7 +939,11 @@ def local_reshape_to_dimshuffle(fgraph, node): if index != output.type.ndim: inner = op.__class__(len(new_output_shape))(inp, new_output_shape) copy_stack_trace(output, inner) - new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] + new_node = [ + DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)( + inner + ) + ] copy_stack_trace(output, new_node) return new_node @@ -1096,10 +1108,10 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): new_order = node.inputs[0].owner.op.new_order inp = node.inputs[0].owner.inputs[0] - broadcastables = node.inputs[0].broadcastable new_order_of_nonbroadcast = [] - for i, bd in zip(new_order, broadcastables): - if not bd: + for i, s in zip(new_order, node.inputs[0].type.shape): + if s != 1: new_order_of_nonbroadcast.append(i) no_change_in_order = all( new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] @@ -1123,7 +1134,11 @@ def local_useless_unbroadcast(fgraph, node): """ if isinstance(node.op, Unbroadcast): x = node.inputs[0] - if x.broadcastable == node.outputs[0].broadcastable: + if 
x.type.ndim == node.outputs[0].type.ndim and all( + s1 == s2 + for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape) + if s1 == 1 or s2 == 1 + ): # No broadcastable flag was modified # No need to copy over stack trace, # because x should already have a stack trace. diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 6f396d5b18..e93829d6d2 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -55,13 +55,13 @@ def test_shape_basic(): s = shape([]) - assert s.type.broadcastable == (True,) + assert s.type.shape == (1,) s = shape([10]) - assert s.type.broadcastable == (True,) + assert s.type.shape == (1,) s = shape(lscalar()) - assert s.type.broadcastable == (False,) + assert s.type.shape == (0,) class MyType(Type): def filter(self, *args, **kwargs): @@ -71,7 +71,7 @@ def __eq__(self, other): return isinstance(other, MyType) and other.thingy == self.thingy s = shape(Variable(MyType(), None)) - assert s.type.broadcastable == (False,) + assert s.type.shape == (None,) s = shape(np.array(1)) assert np.array_equal(eval_outputs([s]), []) @@ -119,15 +119,14 @@ def test_basics(self): b = dmatrix() d = dmatrix() - # basic to 1 dim(without list) - c = reshape(b, as_tensor_variable(6), ndim=1) - f = self.function([b], c) - b_val1 = np.asarray([[0, 1, 2], [3, 4, 5]]) c_val1 = np.asarray([0, 1, 2, 3, 4, 5]) b_val2 = b_val1.T c_val2 = np.asarray([0, 3, 1, 4, 2, 5]) + # basic to 1 dim(without list) + c = reshape(b, as_tensor_variable(6), ndim=1) + f = self.function([b], c) f_out1 = f(b_val1) f_out2 = f(b_val2) assert np.array_equal(f_out1, c_val1), (f_out1, c_val1) @@ -191,10 +190,10 @@ def just_vals(v): f(np.asarray([[0, 1, 2], [3, 4, 5]])), np.asarray([[[0], [1], [2]], [[3], [4], [5]]]), ) - assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == ( - False, - False, - True, + assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == ( + None, + None, + 1, ) # test broadcast flag for constant value of 1 if it cannot 
be @@ -205,10 +204,10 @@ def just_vals(v): f(np.asarray([[0, 1, 2], [3, 4, 5]])), np.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]), ) - assert f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == ( - False, - False, - True, + assert f.maker.fgraph.toposort()[-1].outputs[0].type.shape == ( + None, + None, + 1, ) def test_m1(self):