Replace use of broadcastable with shape in aesara.tensor.basic
brandonwillard committed Nov 4, 2022
1 parent 92389c2 commit 74bd45f
Showing 2 changed files with 57 additions and 52 deletions.
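
The change is mechanical throughout: queries of a type's broadcastable flags become queries of its static shape, where an entry of 1 marks a broadcastable dimension and None an unknown length. A minimal sketch of the equivalence, assuming an Aesara build with static shapes in TensorType (which this commit targets), not part of the diff itself:

# Illustrative sketch, not part of the commit.
from aesara.tensor.type import TensorType

x = TensorType("float64", shape=(1, None, 3))()

# Old idiom: read the broadcastable flags directly.
assert x.type.broadcastable == (True, False, False)

# New idiom: derive the same pattern from the static shape,
# where 1 means broadcastable and None means unknown.
assert x.type.shape == (1, None, 3)
assert tuple(s == 1 for s in x.type.shape) == x.type.broadcastable
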
57 changes: 31 additions & 26 deletions aesara/tensor/basic.py
@@ -28,7 +28,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type
from aesara.graph.type import HasShape, Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
@@ -348,8 +348,8 @@ def get_scalar_constant_value(
if isinstance(inp, Constant):
return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1
if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]:
return np.asarray(1)
if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
return np.asarray(inp.type.shape[i])

# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
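
The point of the HasShape change above: any statically known extent can now be folded to a constant, not only the length-1 case that a broadcastable flag could express. A hedged example of the post-commit behavior, assuming get_scalar_constant_value resolves subtensors of Shape as the edited branch suggests:

# Illustrative sketch, not part of the commit.
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.type import TensorType

x = TensorType("float64", shape=(None, 5))()

# x.shape[1] is statically known to be 5, so it folds to a constant
# even though dimension 1 is not broadcastable.
assert get_scalar_constant_value(x.shape[1]) == 5
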
@@ -502,21 +502,16 @@ def get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable
gp_shape = grandparent.type.shape
ndim = grandparent.type.ndim
if grandparent.owner and isinstance(
grandparent.owner.op, Unbroadcast
):
ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
l = [
b1 or b2
for b1, b2 in zip(ggp_broadcastable, gp_broadcastable)
]
gp_broadcastable = tuple(l)
ggp_shape = grandparent.owner.inputs[0].type.shape
l = [s1 == 1 or s2 == 1 for s1, s2 in zip(ggp_shape, gp_shape)]
gp_shape = tuple(l)

assert ndim == len(gp_broadcastable)

if not (idx < len(gp_broadcastable)):
if not (idx < ndim):
msg = (
"get_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] "
@@ -528,8 +523,9 @@
msg += f" x={v}"
raise ValueError(msg)

if gp_broadcastable[idx]:
return np.asarray(1)
gp_shape_val = gp_shape[idx]
if gp_shape_val is not None and gp_shape_val > -1:
return np.asarray(gp_shape_val)

if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])
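
Same generalization, stated in plain Python: the old code could conclude x.shape[idx] == 1 only from a broadcastable flag, while the new code can return any statically known entry. A small sketch of the distinction:

# Illustrative sketch, not part of the commit.
static_shape = (1, None, 4)

# Old: the broadcastable pattern only records which entries equal 1.
broadcastable = tuple(s == 1 for s in static_shape)
assert broadcastable == (True, False, False)

# New: any non-None entry is a usable constant, not just 1.
known = {i: s for i, s in enumerate(static_shape) if s is not None}
assert known == {0: 1, 2: 4}
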
@@ -1511,15 +1507,16 @@ def grad(self, inputs, grads):
axis_kept = []
for i, (ib, gb) in enumerate(
zip(
inputs[0].broadcastable,
inputs[0].type.shape,
# We need the dimensions corresponding to x
grads[0].broadcastable[-inputs[0].ndim :],
grads[0].type.shape[-inputs[0].ndim :],
)
):
if ib and not gb:
if ib == 1 and gb != 1:
axis_broadcasted.append(i + n_axes_to_sum)
else:
axis_kept.append(i)

gx = gz.sum(axis=axis + axis_broadcasted)
if axis_broadcasted:
new_order = ["x"] * x.ndim
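
The grad logic above flags input dimensions that were broadcast against the output: statically 1 in the input but not in the gradient. Those axes must be summed out so the gradient matches the input's shape. A NumPy sketch of the rule, with made-up shapes:

# Illustrative sketch, not part of the commit.
import numpy as np

x_shape = (1, 3)  # input static shape
g_shape = (4, 3)  # gradient (output) static shape

axis_broadcasted = [
    i
    for i, (ib, gb) in enumerate(zip(x_shape, g_shape))
    if ib == 1 and gb != 1
]
assert axis_broadcasted == [0]

# Summing over the broadcast axes restores the input's shape.
gz = np.ones(g_shape)
gx = gz.sum(axis=tuple(axis_broadcasted), keepdims=True)
assert gx.shape == x_shape
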
@@ -1865,11 +1862,14 @@ def transpose(x, axes=None):
"""
_x = as_tensor_variable(x)

if axes is None:
axes = list(range((_x.ndim - 1), -1, -1))
ret = DimShuffle(_x.broadcastable, axes)(_x)
if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
axes = list(range((_x.type.ndim - 1), -1, -1))
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)

if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"

return ret
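
DimShuffle still takes a broadcastable pattern as its first argument, so transpose now derives one from the static shape with s == 1; a None (unknown) entry correctly maps to non-broadcastable. A sketch using the same constructor call as the diff:

# Illustrative sketch, not part of the commit.
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.type import TensorType

x = TensorType("float64", shape=(1, None, 3))()
axes = list(range(x.type.ndim - 1, -1, -1))  # full reversal, as in transpose

ret = DimShuffle(tuple(s == 1 for s in x.type.shape), axes)(x)
# Dimension 0 (statically 1) ends up last after the reversal.
assert ret.type.shape[-1] == 1
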


@@ -3207,11 +3207,11 @@ def _rec_perform(self, node, x, y, inverse, out, curdim):
if xs0 == ys0:
for i in range(xs0):
self._rec_perform(node, x[i], y[i], inverse, out[i], curdim + 1)
elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
elif ys0 == 1 and node.inputs[1].type.shape[curdim] == 1:
# Broadcast y
for i in range(xs0):
self._rec_perform(node, x[i], y[0], inverse, out[i], curdim + 1)
elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
elif xs0 == 1 and node.inputs[0].type.shape[curdim] == 1:
# Broadcast x
for i in range(ys0):
self._rec_perform(node, x[0], y[i], inverse, out[i], curdim + 1)
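
The broadcast branches of _rec_perform reuse index 0 of the length-1 operand at each level, mirroring NumPy broadcasting. A NumPy sketch of the ys0 == 1 case, with made-up arrays:

# Illustrative sketch, not part of the commit.
import numpy as np

x = np.arange(6).reshape(3, 2)
y = np.array([[1, 0]])  # shape (1, 2): broadcast along axis 0

out = np.empty_like(x)
for i in range(x.shape[0]):
    # y has a single row, so every row of x is permuted by y[0].
    out[i] = x[i][y[0]]

assert np.array_equal(out, x[:, [1, 0]])
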
@@ -3270,7 +3270,7 @@ def grad(self, inp, grads):
broadcasted_dims = [
dim
for dim in range(gz.type.ndim)
if x.type.broadcastable[dim] and not gz.type.broadcastable[dim]
if x.type.shape[dim] == 1 and gz.type.shape[dim] != 1
]
gx = Sum(axis=broadcasted_dims)(gx)

@@ -3285,8 +3285,13 @@ def grad(self, inp, grads):
newdims.append(i)
i += 1

gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
assert gx.type.broadcastable == x.type.broadcastable
gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
for s1, s2 in zip(gx.type.shape, x.type.shape)
if s1 == 1 or s2 == 1
)

# if x is an integer type, then so is the output.
# this means f(x+eps) = f(x) so the gradient with respect
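
The new assertion in the grad above compares static shapes only at positions where one side is statically 1, since None entries carry nothing to contradict. The same check as a standalone helper (the name is hypothetical):

# Illustrative sketch, not part of the commit.
def shapes_consistent(shape_a, shape_b):
    # Compare only where a statically known 1 is involved; None
    # (unknown) entries cannot contradict anything.
    return len(shape_a) == len(shape_b) and all(
        s1 == s2 for s1, s2 in zip(shape_a, shape_b) if s1 == 1 or s2 == 1
    )

assert shapes_consistent((1, None), (1, None))
assert shapes_consistent((None, 5), (3, None))  # no 1s involved: passes
assert not shapes_consistent((1, 3), (4, 3))
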
52 changes: 26 additions & 26 deletions tests/tensor/test_basic.py
@@ -458,10 +458,10 @@ def test_make_vector_fail(self):
res = MakeVector("int32")(a, b)

res = MakeVector()(a)
assert res.broadcastable == (True,)
assert res.type.shape == (1,)

res = MakeVector()()
assert res.broadcastable == (False,)
assert res.type.shape == (0,)

def test_infer_shape(self):
adscal = dscalar()
@@ -1665,18 +1665,18 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(1, a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1

# The rewrites can replace the int with an Aesara constant
c = self.join_op(constant(1), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1

# In case future rewrites insert other useless stuff
c = self.join_op(cast(constant(1), dtype="int32"), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1

f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
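
These join tests encode a static-shape rule for dimensions off the join axis: all inputs must match at runtime, so any statically known length pins the output's. A plain-Python sketch of that rule (the helper is hypothetical, not Aesara API):

# Illustrative sketch, not part of the commit.
def join_static_dim(dims):
    # Off the join axis, inputs must agree at runtime, so a single
    # statically known length determines the output's length.
    known = {d for d in dims if d is not None}
    return known.pop() if known else None

# join(1, a, b) with a: (None, None, 1) and b: (1, None, 1)
assert join_static_dim((None, 1)) == 1         # dim 0: pinned to 1 by b
assert join_static_dim((1, 1)) == 1            # dim 2: 1 in every input
assert join_static_dim((None, None)) is None   # dim 1: nothing known
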
@@ -1703,7 +1703,7 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self):
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b)
assert not c.type.broadcastable[0]
assert c.type.shape[0] != 1

f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1736,7 +1736,7 @@ def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
a = self.shared(a_val, shape=(1, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(0, a, b)
assert not c.type.broadcastable[0]
assert c.type.shape[0] != 1

f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1754,9 +1754,9 @@ def test_broadcastable_single_input_broadcastable_dimension(self):
a_val = rng.random((1, 4, 1)).astype(self.floatX)
a = self.shared(a_val, shape=(1, None, 1))
b = self.join_op(0, a)
assert b.type.broadcastable[0]
assert b.type.broadcastable[2]
assert not b.type.broadcastable[1]
assert b.type.shape[0] == 1
assert b.type.shape[2] == 1
assert b.type.shape[1] != 1

f = function([], b, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1782,13 +1782,13 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
f = self.join_op(0, a, b, c, d, e)
fb = f.type.broadcastable
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
g = self.join_op(1, a, b, c, d, e)
gb = g.type.broadcastable
gb = tuple(s == 1 for s in g.type.shape)
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
h = self.join_op(4, a, b, c, d, e)
hb = h.type.broadcastable
hb = tuple(s == 1 for s in h.type.shape)
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]

f = function([a, b, c, d, e], f, mode=self.mode)
@@ -1981,8 +1981,8 @@ def test_TensorFromScalar():
s = aes.constant(56)
t = tensor_from_scalar(s)
assert t.owner.op is tensor_from_scalar
assert t.type.broadcastable == (), t.type.broadcastable
assert t.type.ndim == 0, t.type.ndim
assert t.type.shape == ()
assert t.type.ndim == 0
assert t.type.dtype == s.type.dtype

v = eval_outputs([t])
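
For reference, the updated TensorFromScalar assertions check the 0-d case: an empty static shape tuple and ndim == 0 say the same thing. A condensed sketch of the check, assuming tensor_from_scalar is importable from aesara.tensor.basic, where the diff defines it:

# Illustrative sketch, not part of the commit.
import aesara.scalar as aes
from aesara.tensor.basic import tensor_from_scalar

t = tensor_from_scalar(aes.constant(56))
assert t.type.shape == () and t.type.ndim == 0
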
@@ -2129,23 +2129,23 @@ def test_flatten_broadcastable():

inp = TensorType("float64", shape=(None, None, None, None))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)

inp = TensorType("float64", shape=(None, None, None, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)

inp = TensorType("float64", shape=(None, 1, None, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, False)
assert out.type.shape == (None, None)

inp = TensorType("float64", shape=(None, 1, 1, 1))()
out = flatten(inp, ndim=2)
assert out.broadcastable == (False, True)
assert out.type.shape == (None, 1)

inp = TensorType("float64", shape=(1, None, 1, 1))()
out = flatten(inp, ndim=3)
assert out.broadcastable == (True, False, True)
assert out.type.shape == (1, None, 1)
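
The flatten tests above all follow one rule: the collapsed trailing dimension is statically 1 only when every collapsed input dimension is statically 1, and unknown otherwise. A sketch with a hypothetical helper:

# Illustrative sketch, not part of the commit.
def flattened_static_shape(in_shape, ndim):
    kept = in_shape[: ndim - 1]
    collapsed = in_shape[ndim - 1 :]
    last = 1 if all(s == 1 for s in collapsed) else None
    return tuple(kept) + (last,)

assert flattened_static_shape((None, 1, 1, 1), ndim=2) == (None, 1)
assert flattened_static_shape((1, None, 1, 1), ndim=3) == (1, None, 1)
assert flattened_static_shape((None, None, None, 1), ndim=2) == (None, None)
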


def test_flatten_ndim_invalid():
@@ -2938,8 +2938,8 @@ def permute_fixed(s_input):

def test_3b_2(self):
# Test permute_row_elements on a more complex broadcasting pattern:
# input.type.broadcastable = (False, True, False),
# p.type.broadcastable = (False, False).
# input.type.shape = (None, 1, None),
# p.type.shape = (None, None).

input = TensorType("floatX", shape=(None, 1, None))()
p = imatrix()
@@ -4046,7 +4046,7 @@ def test_broadcasted(self):
B = np.asarray(np.random.random((4, 1)), dtype="float32")
for m in self.modes:
f = function([a, b], choose(a, b, mode=m))
assert choose(a, b, mode=m).broadcastable[0]
assert choose(a, b, mode=m).type.shape[0] == 1
t_c = f(A, B)
n_c = np.choose(A, B, mode=m)
assert np.allclose(t_c, n_c)
