From 10914b97a86fc20702695d669c0dc9b6b0d0c8ed Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 3 May 2024 09:56:54 -0500 Subject: [PATCH 1/3] initial --- src/include/migraphx/op/squeeze.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/op/squeeze.hpp b/src/include/migraphx/op/squeeze.hpp index 5d09c1250f4..08e0aba36bd 100644 --- a/src/include/migraphx/op/squeeze.hpp +++ b/src/include/migraphx/op/squeeze.hpp @@ -59,12 +59,14 @@ struct squeeze auto input_shape = inputs[0]; if(input_shape.dynamic()) { + // Allow for any dynamic_dimension that intersects with {1, 1}. + // Assuming that the shape at run-time will be compatible. if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return input_shape.dyn_dims()[axis] != 1; + return not input_shape.dyn_dims().at(axis).intersection(shape::dynamic_dimension{1, 1}).has_value(); ; })) { MIGRAPHX_THROW( - "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"); + "SQUEEZE: dynamic axis dimension should have an intersection with {1, 1}"); } std::vector dyn_dims = {}; if(axes.empty()) From d74e9f13bbdd92ad03f8f6c44bcb795075148809 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 3 May 2024 09:57:03 -0500 Subject: [PATCH 2/3] formatting --- src/include/migraphx/op/squeeze.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/op/squeeze.hpp b/src/include/migraphx/op/squeeze.hpp index 08e0aba36bd..0394b480166 100644 --- a/src/include/migraphx/op/squeeze.hpp +++ b/src/include/migraphx/op/squeeze.hpp @@ -62,7 +62,11 @@ struct squeeze // Allow for any dynamic_dimension that intersects with {1, 1}. // Assuming that the shape at run-time will be compatible. if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return not input_shape.dyn_dims().at(axis).intersection(shape::dynamic_dimension{1, 1}).has_value(); ; + return not input_shape.dyn_dims() + .at(axis) + .intersection(shape::dynamic_dimension{1, 1}) + .has_value(); + ; })) { MIGRAPHX_THROW( From 8b3171491951b17e0de1cbda8b9b65f890d4eab4 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 9 May 2024 10:29:39 -0500 Subject: [PATCH 3/3] Add test --- test/op_shape_test.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 333a7f9ef63..f88cc204935 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -4559,7 +4559,11 @@ TEST_CASE(test_squeeze_dyn) migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}}; expect_shape(s3, migraphx::make_op("squeeze"), s1); - throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + // allowing to squeeze dynamic_dimension that intersect with {1, 1} + migraphx::shape s4{migraphx::shape::float_type, {{1, 1}, {3, 3}, {1, 1}, {3, 3}}}; + expect_shape(s4, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + + throws_shape(migraphx::make_op("squeeze", {{"axes", {2}}}), s1); } TEST_CASE(test_squeeze_dyn_neg_axes)