diff --git a/src/include/migraphx/op/broadcast_with_dims.hpp b/src/include/migraphx/op/broadcast_with_dims.hpp
index 489a7fd25fd..66267008c26 100644
--- a/src/include/migraphx/op/broadcast_with_dims.hpp
+++ b/src/include/migraphx/op/broadcast_with_dims.hpp
@@ -55,7 +55,7 @@ struct broadcast_with_dims
         migraphx::check_shapes{inputs, *this, true}.has(2);
         // check that second input has a static shape
         (void)migraphx::check_shapes{inputs.begin() + 1, inputs.end(), *this, false};
-        // output tensor rank greater of input_tensor rank or length of dims vector
+        // output tensor rank is the greater of the input_tensor rank and the dims vector length
         const auto& input_tensor_shape = inputs.at(0);
         const auto& dims_shape         = inputs.at(1);
         size_t out_ndim = std::max(input_tensor_shape.ndim(), dims_shape.lens().at(0));
diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp
index 5a3d5bce186..3785c19609f 100644
--- a/src/simplify_dyn_ops.cpp
+++ b/src/simplify_dyn_ops.cpp
@@ -33,6 +33,35 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+/**
+ * Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input
+ * into a multibroadcast op with a static output shape attribute.
+ *
+ */
+struct find_broadcast_with_dims_static
+{
+    auto matcher() const
+    {
+        return match::name("broadcast_with_dims")(match::nargs(2),
+                                                  match::arg(0)(match::static_shape()),
+                                                  match::arg(1)(match::is_constant()));
+    }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto ins    = mr.result;
+        auto inputs = ins->inputs();
+
+        // read the values of arg(1) to build the static out_lens attribute for multibroadcast
+        std::vector<std::size_t> sizes_vec;
+        inputs.at(1)->eval().visit(
+            [&](auto output) { sizes_vec.assign(output.begin(), output.end()); });
+
+        m.replace_instruction(
+            ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0));
+    }
+};
+
 /**
  * Convert a Resize op. with Nearest mode to an implementation using Gather op.
  * From: resize[scales={...}/sizes={...},](static, constant)
@@ -586,6 +615,7 @@ struct simplify_select_module_output_shape
 void simplify_dyn_ops::apply(module& m) const
 {
     match::find_matches(m,
+                        find_broadcast_with_dims_static{},
                         find_resize_static{},
                         find_static_dimensions_of{},
                         find_const_alloc_reshapes{},
diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp
index 25f3f9d6d1c..6a0a476c381 100644
--- a/test/op_shape_test.cpp
+++ b/test/op_shape_test.cpp
@@ -926,6 +926,22 @@ TEST_CASE(broadcast_for_dot_dyn1)
         s0);
 }
 
+TEST_CASE(broadcast_for_dot_dyn2)
+{
+    migraphx::shape s0{migraphx::shape::float_type, {{6, 12}, {4, 4}, {8, 8}}};
+    migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {2, 10}, {8, 8}, {4, 4}}};
+    expect_shape(
+        migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {4, 4}, {8, 8}}},
+        migraphx::make_op("broadcast_for_dot"),
+        s0,
+        s1);
+    expect_shape(
+        migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {8, 8}, {4, 4}}},
+        migraphx::make_op("broadcast_for_dot"),
+        s1,
+        s0);
+}
+
 TEST_CASE(broadcast_with_dims0)
 {
     using migraphx::shape;
@@ -950,22 +966,6 @@ TEST_CASE(broadcast_with_dims1)
         s1);
 }
 
-TEST_CASE(broadcast_for_dot_dyn2)
-{
-    migraphx::shape s0{migraphx::shape::float_type, {{6, 12}, {4, 4}, {8, 8}}};
-    migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {2, 10}, {8, 8}, {4, 4}}};
-    expect_shape(
-        migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {4, 4}, {8, 8}}},
-        migraphx::make_op("broadcast_for_dot"),
-        s0,
-        s1);
-    expect_shape(
-        migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {8, 8}, {4, 4}}},
-        migraphx::make_op("broadcast_for_dot"),
-        s1,
-        s0);
-}
-
 TEST_CASE(flatten_shape)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
diff --git a/test/simplify_dyn_ops_test.cpp b/test/simplify_dyn_ops_test.cpp
index caad3cb2771..ebc79b0bce4 100644
--- a/test/simplify_dyn_ops_test.cpp
+++ b/test/simplify_dyn_ops_test.cpp
@@ -34,6 +34,56 @@ void run_pass(migraphx::module& m)
     migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
 }
 
+TEST_CASE(broadcast_with_dims)
+{
+    migraphx::module m0;
+    {
+        // the X input
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
+        auto inx = m0.add_parameter("x", sx);
+
+        // the shape input; broadcast to these lengths
+        migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
+        std::vector<int64_t> dims = {2, 3, 4, 5};
+        auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});
+
+        auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
+        m0.add_return({r});
+    }
+    run_pass(m0);
+
+    migraphx::module m1;
+    {
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
+        auto inx = m1.add_parameter("x", sx);
+
+        auto r = m1.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), inx);
+        m1.add_return({r});
+    }
+    EXPECT(m0 == m1);
+}
+
+TEST_CASE(broadcast_with_dims_invalid)
+{
+    migraphx::module m0;
+    {
+        // X input shape is not broadcastable to the given shape
+        migraphx::shape sx{migraphx::shape::float_type, {3, 1, 2}};
+        auto inx = m0.add_parameter("x", sx);
+
+        // the shape input; broadcast to these lengths
+        migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
+        std::vector<int64_t> dims = {2, 3, 4, 5};
+        auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});
+
+        auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
+        m0.add_return({r});
+    }
+    // replacement will be rejected by the multibroadcast operation
+    EXPECT(test::throws([&] { run_pass(m0); }));
+}
+
 TEST_CASE(resize)
 {
     migraphx::module m0;
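
Note on the broadcasting rule the new matcher relies on (not part of the patch): broadcast_with_dims follows numpy-style multidirectional broadcasting, so the output rank is the greater of the input rank and the dims length, corresponding dimensions must either match or be 1, and the larger of each pair is kept. The standalone sketch below is plain C++, independent of MIGraphX; the helper name broadcast_out_lens is hypothetical. It reproduces that rule under those assumptions, which is also why the broadcast_with_dims_invalid test above expects the rewritten multibroadcast to reject the {3, 1, 2} input.

// Standalone sketch (assumption: numpy-style multidirectional broadcasting,
// as applied by broadcast_with_dims and the multibroadcast op that replaces it).
// Not MIGraphX source; broadcast_out_lens is a hypothetical helper name.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> broadcast_out_lens(std::vector<std::size_t> input,
                                            std::vector<std::size_t> dims)
{
    // output rank is the greater of the input rank and the dims length
    const std::size_t out_ndim = std::max(input.size(), dims.size());
    // left-pad the shorter shape with 1s so both have out_ndim dimensions
    input.insert(input.begin(), out_ndim - input.size(), 1);
    dims.insert(dims.begin(), out_ndim - dims.size(), 1);

    std::vector<std::size_t> out(out_ndim);
    for(std::size_t i = 0; i < out_ndim; ++i)
    {
        // each pair of dimensions must match or one of them must be 1
        if(input[i] != dims[i] and input[i] != 1 and dims[i] != 1)
            throw std::runtime_error("shapes are not broadcastable");
        out[i] = std::max(input[i], dims[i]);
    }
    return out;
}

int main()
{
    // mirrors the broadcast_with_dims test above: {3, 1, 1} broadcast with dims {2, 3, 4, 5}
    for(auto len : broadcast_out_lens({3, 1, 1}, {2, 3, 4, 5}))
        std::cout << len << ' '; // prints: 2 3 4 5
    std::cout << '\n';

    // the invalid case: {3, 1, 2} cannot broadcast against {2, 3, 4, 5}, so this throws
    try
    {
        broadcast_out_lens({3, 1, 2}, {2, 3, 4, 5});
    }
    catch(const std::runtime_error& e)
    {
        std::cout << e.what() << '\n';
    }
    return 0;
}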