Skip to content

Commit

Permalink
Fix bug where ShapeFeature would create circular shape graph
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 2, 2024
1 parent f951743 commit e180927
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
12 changes: 11 additions & 1 deletion pytensor/tensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,17 @@ def update_shape(self, r, other_r):
# This mean the shape is equivalent
# We do not want to do the ancestor check in those cases
merged_shape.append(r_shape[i])
elif r_shape[i] in ancestors([other_shape[i]]):
elif any(
(
r_shape[i] == anc
or (
anc.owner
and isinstance(anc.owner.op, Shape)
and anc.owner.inputs[0] == r
)
)
for anc in ancestors([other_shape[i]])
):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
# we do not want to substitute an expression with another that
Expand Down
21 changes: 20 additions & 1 deletion tests/tensor/rewriting/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.type import Type
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.basic import alloc, as_tensor_variable
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import add, exp, maximum
from pytensor.tensor.rewriting.basic import register_specialize
Expand Down Expand Up @@ -239,6 +239,25 @@ def test_no_shapeopt(self):
# FIXME: This is not a good test.
f([[1, 2], [2, 3]])

def test_shape_of_useless_alloc(self):
"""Test that local_shape_to_shape_i does not create circular graph.
Regression test for #565
"""
alpha = vector(shape=(None,), dtype="float64")
channel = vector(shape=(None,), dtype="float64")

broadcast_channel = alloc(
channel,
maximum(
shape(alpha)[0],
shape(channel)[0],
),
)
out = shape(broadcast_channel)
fn = function([alpha, channel], out)
assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,)


class TestReshape:
def setup_method(self):
Expand Down

0 comments on commit e180927

Please sign in to comment.