diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 39c149ad87..907f89d84a 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -423,7 +423,17 @@ def update_shape(self, r, other_r): # This mean the shape is equivalent # We do not want to do the ancestor check in those cases merged_shape.append(r_shape[i]) - elif r_shape[i] in ancestors([other_shape[i]]): + elif any( + ( + r_shape[i] == anc + or ( + anc.owner + and isinstance(anc.owner.op, Shape) + and anc.owner.inputs[0] == r + ) + ) + for anc in ancestors([other_shape[i]]) + ): # Another case where we want to use r_shape[i] is when # other_shape[i] actually depends on r_shape[i]. In that case, # we do not want to substitute an expression with another that diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index c0fd7513b3..e094aa8ed5 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -15,7 +15,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.type import Type -from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.basic import alloc, as_tensor_variable from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import add, exp, maximum from pytensor.tensor.rewriting.basic import register_specialize @@ -239,6 +239,25 @@ def test_no_shapeopt(self): # FIXME: This is not a good test. f([[1, 2], [2, 3]]) + def test_shape_of_useless_alloc(self): + """Test that local_shape_to_shape_i does not create circular graph. + + Regression test for #565 + """ + alpha = vector(shape=(None,), dtype="float64") + channel = vector(shape=(None,), dtype="float64") + + broadcast_channel = alloc( + channel, + maximum( + shape(alpha)[0], + shape(channel)[0], + ), + ) + out = shape(broadcast_channel) + fn = function([alpha, channel], out) + assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,) + class TestReshape: def setup_method(self):