Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcast with dims matcher #2927

Merged
merged 10 commits into from
Apr 11, 2024
2 changes: 1 addition & 1 deletion src/include/migraphx/op/broadcast_with_dims.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct broadcast_with_dims
migraphx::check_shapes{inputs, *this, true}.has(2);
// check that second input has a static shape
(void)migraphx::check_shapes{inputs.begin() + 1, inputs.end(), *this, false};
// output tensor rank greater of input_tensor rank or length of dims vector
// output tensor rank is greater of input_tensor rank or length of dims vector
const auto& input_tensor_shape = inputs.at(0);
const auto& dims_shape = inputs.at(1);
size_t out_ndim = std::max(input_tensor_shape.ndim(), dims_shape.lens().at(0));
Expand Down
30 changes: 30 additions & 0 deletions src/simplify_dyn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,35 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input
* into multibroadcast op with a static output shape attribute.
*
*/
struct find_broadcast_with_dims_static
{
auto matcher() const
{
return match::name("broadcast_with_dims")(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::is_constant()));
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();

// read the values of arg(1) to create input to multibroadcast
std::vector<size_t> sizes_vec;
inputs.at(1)->eval().visit(
[&](auto output) { sizes_vec.assign(output.begin(), output.end()); });

m.replace_instruction(
ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0));
}
};

/**
* Convert a Resize op. with Nearest mode to an implementation using Gather op.
* From: resize[scales={...}/sizes={...},](static, constant)
Expand Down Expand Up @@ -586,6 +615,7 @@ struct simplify_select_module_output_shape
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
find_broadcast_with_dims_static{},
find_resize_static{},
find_static_dimensions_of{},
find_const_alloc_reshapes{},
Expand Down
32 changes: 16 additions & 16 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,22 @@ TEST_CASE(broadcast_for_dot_dyn1)
s0);
}

TEST_CASE(broadcast_for_dot_dyn2)
{
migraphx::shape s0{migraphx::shape::float_type, {{6, 12}, {4, 4}, {8, 8}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {2, 10}, {8, 8}, {4, 4}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {4, 4}, {8, 8}}},
migraphx::make_op("broadcast_for_dot"),
s0,
s1);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {8, 8}, {4, 4}}},
migraphx::make_op("broadcast_for_dot"),
s1,
s0);
}

TEST_CASE(broadcast_with_dims0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CharlieL7 I think these two tests were from you, but Github is showing them as new to this PR. Did you mean to keep them, or not?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests look the same to me, just moved around.

{
using migraphx::shape;
Expand All @@ -950,22 +966,6 @@ TEST_CASE(broadcast_with_dims1)
s1);
}

TEST_CASE(broadcast_for_dot_dyn2)
{
migraphx::shape s0{migraphx::shape::float_type, {{6, 12}, {4, 4}, {8, 8}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {2, 10}, {8, 8}, {4, 4}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {4, 4}, {8, 8}}},
migraphx::make_op("broadcast_for_dot"),
s0,
s1);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {8, 8}, {4, 4}}},
migraphx::make_op("broadcast_for_dot"),
s1,
s0);
}

TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
Expand Down
50 changes: 50 additions & 0 deletions test/simplify_dyn_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,56 @@ void run_pass(migraphx::module& m)
migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
}

TEST_CASE(broadcast_with_dims)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CharlieL7 can you think of any other tests needed for this matcher PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good as is

{
migraphx::module m0;
{
// the X input
migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
auto inx = m0.add_parameter("x", sx);

// the shape input. Broadcast to this
migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
std::vector<size_t> dims = {2, 3, 4, 5};
auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});

auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
m0.add_return({r});
}
run_pass(m0);

migraphx::module m1;
{
migraphx::shape sx{migraphx::shape::float_type, {3, 1, 1}};
auto inx = m1.add_parameter("x", sx);

auto r = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), inx);
m1.add_return({r});
}
EXPECT(m0 == m1);
}

TEST_CASE(broadcast_with_dims_invalid)
{
migraphx::module m0;
{
// X input shape is not broadcastable to given shape
migraphx::shape sx{migraphx::shape::float_type, {3, 1, 2}};
auto inx = m0.add_parameter("x", sx);

// the shape input. Broadcast to this
migraphx::shape dims_s{migraphx::shape::int64_type, {4}};
std::vector<size_t> dims = {2, 3, 4, 5};
auto out_dims = m0.add_literal(migraphx::literal{dims_s, dims});

auto r = m0.add_instruction(migraphx::make_op("broadcast_with_dims"), inx, out_dims);
m0.add_return({r});
}
// replacement will be rejected by multibroadcast operation
EXPECT(test::throws([&] { run_pass(m0); }));
}

TEST_CASE(resize)
{
migraphx::module m0;
Expand Down
Loading