Skip to content

Commit

Permalink
Replace use of broadcastable with shape in aesara.tensor.subtensor
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 4, 2022
1 parent b586063 commit fa01971
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def test_ok_elem_2(self):
def test_ok_row(self):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[1]
assert not any(n.type.broadcastable)
assert not any(s == 1 for s in n.type.shape)
assert isinstance(t.owner.op, Subtensor)
tval = self.eval_output_and_check(t)
assert tval.shape == (3,)
Expand All @@ -475,7 +475,7 @@ def test_ok_col(self):
n = self.shared(np.arange(6, dtype=self.dtype).reshape((2, 3)))
t = n[:, 0]
assert isinstance(t.owner.op, Subtensor)
assert not any(n.type.broadcastable)
assert not any(s == 1 for s in n.type.shape)
tval = self.eval_output_and_check(t)
assert tval.shape == (2,)
assert np.all(tval == [0, 3])
Expand Down Expand Up @@ -1773,15 +1773,17 @@ def test_index_into_vec_w_vec(self):
def test_index_into_vec_w_matrix(self):
a = self.v[self.ix2]
assert a.dtype == self.v.dtype, (a.dtype, self.v.dtype)
assert a.broadcastable == self.ix2.broadcastable, (
a.broadcastable,
self.ix2.broadcastable,
assert a.type.ndim == self.ix2.type.ndim
assert all(
s1 == s2
for s1, s2 in zip(a.type.shape, self.ix2.type.shape)
if s1 == 1 or s2 == 1
)

def test_index_into_mat_w_row(self):
a = self.m[self.ixr]
assert a.dtype == self.m.dtype, (a.dtype, self.m.dtype)
assert a.broadcastable == (True, False, False)
assert a.type.shape == (1, None, None)

def test_index_w_int_and_vec(self):
# like test_ok_list, but with a single index on the first one
Expand Down Expand Up @@ -2446,7 +2448,7 @@ def test_AdvancedSubtensor_bool(self):
)

abs_res = n[~isinf(n)]
assert abs_res.broadcastable == (False,)
assert abs_res.type.shape == (None,)


@config.change_flags(compute_test_value="raise")
Expand All @@ -2467,9 +2469,7 @@ def idx_as_tensor(x):
def bcast_shape_tuple(x):
if not hasattr(x, "shape"):
return x
return tuple(
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
)
return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape))


test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
Expand Down

0 comments on commit fa01971

Please sign in to comment.