Skip to content

Commit

Permalink
Squeeze dyn dim allow compatible (#3035)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored May 10, 2024
1 parent 9cdd85d commit 48b49ac
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/include/migraphx/op/squeeze.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ struct squeeze
auto input_shape = inputs[0];
if(input_shape.dynamic())
{
// Allow for any dynamic_dimension that intersects with {1, 1}.
// Assuming that the shape at run-time will be compatible.
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != 1;
return not input_shape.dyn_dims()
.at(axis)
.intersection(shape::dynamic_dimension{1, 1})
.has_value();
;
}))
{
MIGRAPHX_THROW(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
"SQUEEZE: dynamic axis dimension should have an intersection with {1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
Expand Down
6 changes: 5 additions & 1 deletion test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4569,7 +4569,11 @@ TEST_CASE(test_squeeze_dyn)
migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);

throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
// allowing to squeeze dynamic_dimension that intersect with {1, 1}
migraphx::shape s4{migraphx::shape::float_type, {{1, 1}, {3, 3}, {1, 1}, {3, 3}}};
expect_shape(s4, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);

throws_shape(migraphx::make_op("squeeze", {{"axes", {2}}}), s1);
}

TEST_CASE(test_squeeze_dyn_neg_axes)
Expand Down

0 comments on commit 48b49ac

Please sign in to comment.