diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py
index e6db9019cd..4762d903d2 100644
--- a/aesara/tensor/basic.py
+++ b/aesara/tensor/basic.py
@@ -28,7 +28,7 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.rewriting.utils import rewrite_graph
-from aesara.graph.type import Type
+from aesara.graph.type import HasShape, Type
 from aesara.link.c.op import COp
 from aesara.link.c.params_type import ParamsType
 from aesara.misc.safe_asarray import _asarray
@@ -348,8 +348,8 @@ def get_scalar_constant_value(
                 if isinstance(inp, Constant):
                     return np.asarray(np.shape(inp.data)[i])
                 # The shape of a broadcastable dimension is 1
-                if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]:
-                    return np.asarray(1)
+                if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
+                    return np.asarray(inp.type.shape[i])
 
             # Don't act as the constant_folding optimization here as this
             # fct is used too early in the optimization phase. This would
@@ -502,21 +502,16 @@ def get_scalar_constant_value(
                             owner.inputs[1], max_recur=max_recur
                         )
                     grandparent = leftmost_parent.owner.inputs[0]
-                    gp_broadcastable = grandparent.type.broadcastable
+                    gp_shape = grandparent.type.shape
                     ndim = grandparent.type.ndim
                     if grandparent.owner and isinstance(
                         grandparent.owner.op, Unbroadcast
                     ):
-                        ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
-                        l = [
-                            b1 or b2
-                            for b1, b2 in zip(ggp_broadcastable, gp_broadcastable)
-                        ]
-                        gp_broadcastable = tuple(l)
+                        ggp_shape = grandparent.owner.inputs[0].type.shape
+                        l = [s1 == 1 or s2 == 1 for s1, s2 in zip(ggp_shape, gp_shape)]
+                        gp_shape = tuple(l)
 
-                    assert ndim == len(gp_broadcastable)
-
-                    if not (idx < len(gp_broadcastable)):
+                    if not (idx < ndim):
                         msg = (
                             "get_scalar_constant_value detected "
                             f"deterministic IndexError: x.shape[{int(idx)}] "
@@ -528,8 +523,9 @@
                             msg += f" x={v}"
                         raise ValueError(msg)
 
-                    if gp_broadcastable[idx]:
-                        return np.asarray(1)
+                    gp_shape_val = gp_shape[idx]
+                    if gp_shape_val is not None and gp_shape_val > -1:
+                        return np.asarray(gp_shape_val)
 
                     if isinstance(grandparent, Constant):
                         return np.asarray(np.shape(grandparent.data)[idx])
@@ -1511,15 +1507,16 @@ def grad(self, inputs, grads):
         axis_kept = []
         for i, (ib, gb) in enumerate(
             zip(
-                inputs[0].broadcastable,
+                inputs[0].type.shape,
                 # We need the dimensions corresponding to x
-                grads[0].broadcastable[-inputs[0].ndim :],
+                grads[0].type.shape[-inputs[0].ndim :],
             )
         ):
-            if ib and not gb:
+            if ib == 1 and gb != 1:
                 axis_broadcasted.append(i + n_axes_to_sum)
             else:
                 axis_kept.append(i)
+
         gx = gz.sum(axis=axis + axis_broadcasted)
         if axis_broadcasted:
             new_order = ["x"] * x.ndim
@@ -1865,11 +1862,14 @@ def transpose(x, axes=None):
 
     """
     _x = as_tensor_variable(x)
+
     if axes is None:
-        axes = list(range((_x.ndim - 1), -1, -1))
-    ret = DimShuffle(_x.broadcastable, axes)(_x)
-    if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
+        axes = list(range((_x.type.ndim - 1), -1, -1))
+    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
+
+    if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
         ret.name = _x.name + ".T"
+
     return ret
 
 
@@ -3207,11 +3207,11 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
             if xs0 == ys0:
                 for i in range(xs0):
                     self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
-            elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
+            elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
                 # Broadcast y
                 for i in range(xs0):
                     self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
-            elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
+            elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
                 # Broadcast x
                 for i in range(ys0):
                     self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
@@ -3270,7 +3270,7 @@ def grad(self, inp, grads):
         broadcasted_dims = [
             dim
             for dim in range(gz.type.ndim)
-            if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]
+            if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
         ]
         gx = Sum(axis=broadcasted_dims)(gx)
 
@@ -3285,8 +3285,13 @@
                 newdims.append(i)
                 i += 1
 
-        gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
-        assert gx.type.broadcastable == x.type.broadcastable
+        gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
+        assert gx.type.ndim == x.type.ndim
+        assert all(
+            s1 == s2
+            for s1, s2 in zip(gx.type.shape, x.type.shape)
+            if s1 == 1 or s2 == 1
+        )
 
         # if x is an integer type, then so is the output.
         # this means f(x+eps) = f(x) so the gradient with respect
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index eda86d2227..a38c4fa1da 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -458,10 +458,10 @@ def test_make_vector_fail(self):
         res = MakeVector("int32")(a, b)
 
         res = MakeVector()(a)
-        assert res.broadcastable == (True,)
+        assert res.type.shape == (1,)
 
         res = MakeVector()()
-        assert res.broadcastable == (False,)
+        assert res.type.shape == (0,)
 
     def test_infer_shape(self):
         adscal = dscalar()
@@ -1665,18 +1665,18 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
         a = self.shared(a_val, shape=(None, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(1, a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1
 
         # Opt can remplace the int by an Aesara constant
         c = self.join_op(constant(1), a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1
 
         # In case futur opt insert other useless stuff
         c = self.join_op(cast(constant(1), dtype="int32"), a, b)
-        assert c.type.broadcastable[0] and c.type.broadcastable[2]
-        assert not c.type.broadcastable[1]
+        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
+        assert c.type.shape[1] != 1
 
         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1703,7 +1703,7 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self):
         a = self.shared(a_val, shape=(None, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(0, a, b)
-        assert not c.type.broadcastable[0]
+        assert c.type.shape[0] != 1
 
         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1736,7 +1736,7 @@ def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
         a = self.shared(a_val, shape=(1, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(0, a, b)
-        assert not c.type.broadcastable[0]
+        assert c.type.shape[0] != 1
 
         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1754,9 +1754,9 @@ def test_broadcastable_single_input_broadcastable_dimension(self):
         a_val = rng.random((1, 4, 1)).astype(self.floatX)
         a = self.shared(a_val, shape=(1, None, 1))
         b = self.join_op(0, a)
-        assert b.type.broadcastable[0]
-        assert b.type.broadcastable[2]
-        assert not b.type.broadcastable[1]
+        assert b.type.shape[0] == 1
+        assert b.type.shape[2] == 1
+        assert b.type.shape[1] != 1
 
         f = function([], b, mode=self.mode)
         topo = f.maker.fgraph.toposort()
@@ -1782,13 +1782,13 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
         d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
         e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
         f = self.join_op(0, a, b, c, d, e)
-        fb = f.type.broadcastable
+        fb = tuple(s == 1 for s in f.type.shape)
         assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
         g = self.join_op(1, a, b, c, d, e)
-        gb = g.type.broadcastable
+        gb = tuple(s == 1 for s in g.type.shape)
         assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
         h = self.join_op(4, a, b, c, d, e)
-        hb = h.type.broadcastable
+        hb = tuple(s == 1 for s in h.type.shape)
         assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
 
         f = function([a, b, c, d, e], f, mode=self.mode)
@@ -1981,8 +1981,8 @@ def test_TensorFromScalar():
     s = aes.constant(56)
     t = tensor_from_scalar(s)
    assert t.owner.op is tensor_from_scalar
-    assert t.type.broadcastable == (), t.type.broadcastable
-    assert t.type.ndim == 0, t.type.ndim
+    assert t.type.shape == ()
+    assert t.type.ndim == 0
     assert t.type.dtype == s.type.dtype
 
     v = eval_outputs([t])
@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable():
 
     inp = TensorType("float64", shape=(None, None, None, None))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)
 
     inp = TensorType("float64", shape=(None, None, None, 1))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)
 
     inp = TensorType("float64", shape=(None, 1, None, 1))()
     out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, False)
+    assert out.type.shape == (None, None)
 
     inp = TensorType("float64", shape=(None, 1, 1, 1))()
    out = flatten(inp, ndim=2)
-    assert out.broadcastable == (False, True)
+    assert out.type.shape == (None, 1)
 
     inp = TensorType("float64", shape=(1, None, 1, 1))()
     out = flatten(inp, ndim=3)
-    assert out.broadcastable == (True, False, True)
+    assert out.type.shape == (1, None, 1)
 
 
 def test_flatten_ndim_invalid():
@@ -2938,8 +2938,8 @@ def permute_fixed(s_input):
 
     def test_3b_2(self):
         # Test permute_row_elements on a more complex broadcasting pattern:
-        # input.type.broadcastable = (False, True, False),
-        # p.type.broadcastable = (False, False).
+        # input.type.shape = (None, 1, None),
+        # p.type.shape = (None, None).
         input = TensorType("floatX", shape=(None, 1, None))()
         p = imatrix()
 
@@ -4046,7 +4046,7 @@ def test_broadcasted(self):
         B = np.asarray(np.random.random((4, 1)), dtype="float32")
         for m in self.modes:
             f = function([a, b], choose(a, b, mode=m))
-            assert choose(a, b, mode=m).broadcastable[0]
+            assert choose(a, b, mode=m).type.shape[0] == 1
             t_c = f(A, B)
             n_c = np.choose(A, B, mode=m)
            assert np.allclose(t_c, n_c)
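
Note accompanying the patch: the changes consistently map the old `broadcastable` flags onto static shapes, where a dimension known to be broadcastable has static length 1 and an unknown dimension is `None`. Below is a minimal sketch of that correspondence; it uses only the `TensorType(..., shape=...)` constructor already exercised in the tests above, and assumes an Aesara release whose `TensorType` accepts `None` entries in `shape`.

    from aesara.tensor.type import TensorType

    # A static length of 1 is what the old code called a broadcastable
    # dimension; None marks a dimension whose length is unknown.
    x = TensorType("float64", shape=(None, 1, None))()
    assert x.type.shape == (None, 1, None)
    assert x.broadcastable == (False, True, False)

    # The patch rewrites checks such as `x.type.broadcastable[i]` as
    # `x.type.shape[i] == 1`, and rebuilds broadcast patterns for DimShuffle
    # with `tuple(s == 1 for s in x.type.shape)`.
    assert tuple(s == 1 for s in x.type.shape) == x.broadcastable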