Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize broadcast + transpose for nonscalars #2271

Merged
merged 12 commits into from
Oct 14, 2023
23 changes: 22 additions & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -521,6 +521,27 @@ struct find_inner_broadcast
}) < (lens.size() - 1);
}))
return;
if(broadcasts.size() > 1)
{
auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{
for(const auto& broadcast : broadcasts)
{
if(broadcast->get_shape().strides()[i] == 0)
common_axis[i]++;
}
}
// if no common broadcast axis, transformation is not useful
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use std::none_of instead.

return num_common > 1;
}) == common_axis.end())
return;
}

std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
Expand Down
30 changes: 23 additions & 7 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary
}
};

// Rewrites broadcast->transpose as transpose->broadcast.
// In the case of a scalar input, it is simply rewritten to a broadcast.
// This can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp.
struct find_broadcast_transpose
{
auto matcher() const
Expand All @@ -642,17 +645,30 @@ struct find_broadcast_transpose

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
// scalar transformation does not need extra transpose
if(not input->get_shape().scalar())
return;

{
// find common shape
auto in_lens = input->get_shape().lens();
int lens_diff = transpose_lens.size() - in_lens.size();
// insert unsqueeze if input lens < transpose lens
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
// apply transpose before the multibroadcast
input = m.insert_instruction(bcast_ins, transpose->get_operator(), input);
}
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
bcast_ins, make_op("multibroadcast", {{"out_lens", transpose_lens}}), input);
m.replace_instruction(transpose, new_mbcast);
}
};

Expand Down
37 changes: 37 additions & 0 deletions test/optimize_module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,41 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
EXPECT(m1 == m2);
}

TEST_CASE(broadcast_transpose_inner_broadcast_generic)
{
    // Expected rewrite happens in two steps: broadcast+transpose becomes
    // unsqueeze+transpose+broadcast, and the inner broadcasts are then merged
    // so that the result is mul followed by a broadcast.
    auto swap_last_two = migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}});

    // Input graph: x and y are broadcast separately, y is transposed afterwards.
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
        auto y = m1.add_parameter("y", {migraphx::shape::float_type, {5}});
        auto x_bcast =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), x);
        auto y_bcast =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 10, 5}}}), y);
        auto y_transposed = m1.add_instruction(swap_last_two, y_bcast);
        auto product = m1.add_instruction(migraphx::make_op("mul"), x_bcast, y_transposed);
        m1.add_return({product});
    }
    run_pass(m1);

    // Reference graph: the multiply runs on {1, 5, 10} and only its result is
    // broadcast to the full {3, 5, 10} shape.
    migraphx::module m2;
    {
        auto inner_bcast = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}});

        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
        auto y = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
        auto y_unsqueezed =
            m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), y);
        auto y_transposed = m2.add_instruction(swap_last_two, y_unsqueezed);
        auto x_bcast = m2.add_instruction(inner_bcast, x);
        auto y_bcast = m2.add_instruction(inner_bcast, y_transposed);
        auto product = m2.add_instruction(migraphx::make_op("mul"), x_bcast, y_bcast);
        auto outer_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), product);
        m2.add_return({outer_bcast});
    }
    EXPECT(m1 == m2);
}

// Test executable entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
19 changes: 18 additions & 1 deletion test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT(m1 == m2);
}

TEST_CASE(simplify_inner_broadcast_no_common_axis)
{
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {5, 10}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 5, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
Comment on lines +681 to +682
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Includes optimizations to find_inner_broadcast because otherwise a bunch of multibroadcasts were appearing after the transformation

Do you have a test for the specific case where multibroadcasts appear after find_inner_broadcast and are not cleaned up by simplify_reshapes? In this test a multibroadcast will be added, but I think it will later get cleaned up by find_nop_reshaper.

Copy link
Collaborator Author

@kahmed10 kahmed10 Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this exact test will fail currently on develop and insert a bunch of multibroadcasts:

./bin/test_simplify_algebra_test simplify_inner_broadcast_no_common_axis
[   RUN    ] simplify_inner_broadcast_no_common_axis
void simplify_inner_broadcast_no_common_axis()
/code/AMDMIGraphX/test/simplify_algebra_test.cpp:686:
    FAILED: m1 == m2 [ y = @param:y -> int32_type, {1, 5, 1}, {5, 1, 1}, target_id=0
x = @param:x -> int32_type, {5, 10}, {10, 1}, target_id=0
@2 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](x) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@3 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](y) -> int32_type, {1, 5, 10}, {5, 1, 0}, target_id=0
@4 = add(@2,@3) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@5 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@4) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@6 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@5) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@7 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@6) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@8 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@7) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@9 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@8) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@10 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@9) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@11 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@10) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@12 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@11) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@13 = pass(@12) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
 == y = @param:y -> int32_type, {1, 5, 1}, {5, 1, 1}, target_id=0
x = @param:x -> int32_type, {5, 10}, {10, 1}, target_id=0
@2 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](x) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@3 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](y) -> int32_type, {1, 5, 10}, {5, 1, 0}, target_id=0
@4 = add(@2,@3) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@5 = pass(@4) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
 ]
[  FAILED  ] simplify_inner_broadcast_no_common_axis: Test failure
[==========] 1 tests ran
[  FAILED  ] 1 tests failed
[  FAILED  ] simplify_inner_broadcast_no_common_axis

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those multi broadcasts should get cleaned up by find_nop_reshaper later

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, where are all those broadcasts being added? It looks to me like find_inner_broadcast would only add one multibroadcast after the add.

}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(simplify_add_conv1)
{
migraphx::module m;
Expand Down
51 changes: 51 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,57 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return m;
}

TEST_CASE(broadcast_transpose)
{
    // The same permutation is applied in the input graph and in the reference graph.
    auto permute = migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}});

    // Input graph: a 1-D parameter is broadcast to {2, 3, 5} and then transposed.
    migraphx::module m1;
    {
        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
        auto x_bcast =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), x);
        auto x_transposed = m1.add_instruction(permute, x_bcast);
        m1.add_return({x_transposed});
    }
    run_pass(m1);

    // Reference graph: unsqueeze and transpose come first, the multibroadcast
    // to the permuted lens {5, 2, 3} comes last.
    migraphx::module m2;
    {
        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
        auto x_unsqueezed =
            m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), x);
        auto x_transposed = m2.add_instruction(permute, x_unsqueezed);
        auto x_bcast = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), x_transposed);
        m2.add_return({x_bcast});
    }

    EXPECT(m1 == m2);
}

TEST_CASE(broadcast_transpose_opt)
{
// extra transpose from transformation will be optimized out
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1);
m2.add_return({mb});
}

EXPECT(m1 == m2);
}

TEST_CASE(broadcast_transpose_scalar)
{
migraphx::module m1;
Expand Down
Loading