From 00dae07fdc34d9f6c3373ed6bc82c37f01df4c64 Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Mon, 2 Oct 2023 10:59:58 -0500 Subject: [PATCH 1/8] add generic case for broadcast transpose --- src/simplify_algebra.cpp | 29 +++++++++++++++++++++++++++++ src/simplify_reshapes.cpp | 16 +++++++++++++--- test/optimize_module_test.cpp | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 43fb8ec17b9..3a93beb0556 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -521,6 +521,24 @@ struct find_inner_broadcast }) < (lens.size() - 1); })) return; + auto bcast_strides = broadcasts.front()->get_shape().strides().size(); + std::vector common_axis(bcast_strides, 0); + // go through the strides of each broadcast, + // keep track of values that are equal to 0 in a dimension + for(auto i = 0; i < bcast_strides; i++) + { + for(auto j = 0; j < broadcasts.size(); j++) + { + if(broadcasts[j]->get_shape().strides()[i] == 0) + common_axis[i]++; + } + } + // if no common broadcast axis, transformation is not useful + if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { + return num_common > 1; + }) == common_axis.end()) + return; + std::vector inputs; std::transform(broadcasts.begin(), broadcasts.end(), @@ -543,6 +561,17 @@ struct find_inner_broadcast return 3; })); auto op = insert_common_op(m, ins, ins->get_operator(), inputs); + std::vector broadcast_shapes; + std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){ + return broadcast->get_shape(); + }); + std::vector common_shapes; + std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){ + return common->get_shape(); + }); + if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){ + return i->name() == "broadcast" or i->name() == "multibroadcast";})) + return; m.replace_instruction(ins, broadcasts.front()->get_operator(), op); } }; diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 08ea498e720..21dfd30a412 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -641,10 +641,20 @@ struct find_broadcast_transpose auto ins_lens = ins->get_shape().lens(); auto bcast_ins = r.instructions["bcast_ins"]; auto input = bcast_ins->inputs().front(); - // for now, focusing on scalar transformation + // scalar transformation does not need extra transpose if(not input->get_shape().scalar()) - return; - + { + // find common shape + auto in_lens = input->get_shape().lens(); + int lens_diff = ins_lens.size() - in_lens.size(); + if(lens_diff > 0) + { + std::vector unsqueeze_axes(lens_diff); + std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0); + input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); + } + input = m.insert_instruction(bcast_ins, ins->get_operator(), input); + } auto new_mbcast = m.insert_instruction( bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input); m.replace_instruction(ins, new_mbcast); diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp index d54219f1d93..6acb7ca404b 100644 --- a/test/optimize_module_test.cpp +++ b/test/optimize_module_test.cpp @@ -62,4 +62,38 @@ TEST_CASE(broadcast_transpose_inner_broadcast) EXPECT(m1 == m2); } +TEST_CASE(broadcast_transpose_inner_broadcast_generic) +{ + // first optimizes broadcast+transpose to unsqueeze+transpose+broadcast, + // then finds inner broadcast to become mul+broadcast + migraphx::module m1; + { + auto l1 = m1.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); + auto l2 = m1.add_parameter("y", {migraphx::shape::float_type, {5}}); + auto mb1 = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), l1); + auto mb2 = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 10, 5}}}), l2); + auto t1 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb2); + auto mul = m1.add_instruction(migraphx::make_op("mul"), mb1, t1); + m1.add_return({mul}); + } + run_pass(m1); + migraphx::module m2; + { + auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); + auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2); + auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); + auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1); + auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose); + auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2); + auto mb3 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul); + m2.add_return({mb3}); + } + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 0dc81c3b9839592fdf338e68683396acb77fef8d Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Mon, 2 Oct 2023 11:00:04 -0500 Subject: [PATCH 2/8] formatting --- src/simplify_algebra.cpp | 22 +++++++++++++--------- src/simplify_reshapes.cpp | 5 +++-- test/optimize_module_test.cpp | 17 ++++++++++------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 3a93beb0556..f5cddb5cbad 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -524,7 +524,7 @@ struct find_inner_broadcast auto bcast_strides = broadcasts.front()->get_shape().strides().size(); std::vector common_axis(bcast_strides, 0); // go through the strides of each broadcast, - // keep track of values that are equal to 0 in a dimension + // keep track of values that are equal to 0 in a dimension for(auto i = 0; i < bcast_strides; i++) { for(auto j = 0; j < broadcasts.size(); j++) @@ -562,15 +562,19 @@ struct find_inner_broadcast })); auto op = insert_common_op(m, ins, ins->get_operator(), inputs); std::vector broadcast_shapes; - std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){ - return broadcast->get_shape(); - }); + std::transform(broadcasts.begin(), + broadcasts.end(), + std::back_inserter(broadcast_shapes), + [](auto broadcast) { return broadcast->get_shape(); }); std::vector common_shapes; - std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){ - return common->get_shape(); - }); - if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){ - return i->name() == "broadcast" or i->name() == "multibroadcast";})) + std::transform(op->inputs().begin(), + op->inputs().end(), + std::back_inserter(common_shapes), + [](auto common) { return common->get_shape(); }); + if(broadcast_shapes == common_shapes and + std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) { + return i->name() == "broadcast" or i->name() == "multibroadcast"; + })) return; m.replace_instruction(ins, broadcasts.front()->get_operator(), op); } diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 21dfd30a412..f6ed509e819 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -645,13 +645,14 @@ struct find_broadcast_transpose if(not input->get_shape().scalar()) { // find common shape - auto in_lens = input->get_shape().lens(); + auto in_lens = input->get_shape().lens(); int lens_diff = ins_lens.size() - in_lens.size(); if(lens_diff > 0) { std::vector unsqueeze_axes(lens_diff); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0); - input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); + input = m.insert_instruction( + bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); } input = m.insert_instruction(bcast_ins, ins->get_operator(), input); } diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp index 6acb7ca404b..8324be4a12f 100644 --- a/test/optimize_module_test.cpp +++ b/test/optimize_module_test.cpp @@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic) run_pass(m1); migraphx::module m2; { - auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); - auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); + auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); + auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2); - auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); - auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1); - auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose); + auto transpose = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); + auto mb1 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1); + auto mb2 = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose); auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2); - auto mb3 = - m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul); + auto mb3 = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul); m2.add_return({mb3}); } EXPECT(m1 == m2); From fc38213a9569c7faa5d6471f316596471175b8cc Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Mon, 2 Oct 2023 11:04:14 -0500 Subject: [PATCH 3/8] remove unnecessary code from simplify_algebra --- src/simplify_algebra.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index f5cddb5cbad..839dbe3be32 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -561,21 +561,6 @@ struct find_inner_broadcast return 3; })); auto op = insert_common_op(m, ins, ins->get_operator(), inputs); - std::vector broadcast_shapes; - std::transform(broadcasts.begin(), - broadcasts.end(), - std::back_inserter(broadcast_shapes), - [](auto broadcast) { return broadcast->get_shape(); }); - std::vector common_shapes; - std::transform(op->inputs().begin(), - op->inputs().end(), - std::back_inserter(common_shapes), - [](auto common) { return common->get_shape(); }); - if(broadcast_shapes == common_shapes and - std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) { - return i->name() == "broadcast" or i->name() == "multibroadcast"; - })) - return; m.replace_instruction(ins, broadcasts.front()->get_operator(), op); } }; From 3bc77083c05e37a97737f4540fced6cd83a3dfeb Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Mon, 2 Oct 2023 14:55:24 -0500 Subject: [PATCH 4/8] add condition with just one broadcast --- src/simplify_algebra.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 839dbe3be32..5a5363e5851 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -521,23 +521,26 @@ struct find_inner_broadcast }) < (lens.size() - 1); })) return; - auto bcast_strides = broadcasts.front()->get_shape().strides().size(); - std::vector common_axis(bcast_strides, 0); - // go through the strides of each broadcast, - // keep track of values that are equal to 0 in a dimension - for(auto i = 0; i < bcast_strides; i++) + if(broadcasts.size() > 1) { - for(auto j = 0; j < broadcasts.size(); j++) + auto bcast_strides = broadcasts.front()->get_shape().strides().size(); + std::vector common_axis(bcast_strides, 0); + // go through the strides of each broadcast, + // keep track of values that are equal to 0 in a dimension + for(auto i = 0; i < bcast_strides; i++) { - if(broadcasts[j]->get_shape().strides()[i] == 0) - common_axis[i]++; + for(const auto& broadcast : broadcasts) + { + if(broadcast->get_shape().strides()[i] == 0) + common_axis[i]++; + } } + // if no common broadcast axis, transformation is not useful + if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { + return num_common > 1; + }) == common_axis.end()) + return; } - // if no common broadcast axis, transformation is not useful - if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { - return num_common > 1; - }) == common_axis.end()) - return; std::vector inputs; std::transform(broadcasts.begin(), From 8d7948aa9189f60704bb0dc84d073199f7e2f245 Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Mon, 2 Oct 2023 14:55:34 -0500 Subject: [PATCH 5/8] formatting --- src/simplify_algebra.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 5a5363e5851..e4e1c6276c7 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -537,8 +537,8 @@ struct find_inner_broadcast } // if no common broadcast axis, transformation is not useful if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) { - return num_common > 1; - }) == common_axis.end()) + return num_common > 1; + }) == common_axis.end()) return; } From b4a9a9c5f056817d3e779b45cd19fa9834a1f320 Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:40:59 -0500 Subject: [PATCH 6/8] add tests, rename variables --- src/simplify_reshapes.cpp | 14 +++++----- test/simplify_algebra_test.cpp | 17 ++++++++++++ test/simplify_reshapes_test.cpp | 48 +++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index ab85e4a12f4..9c32d661dd4 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -642,8 +642,8 @@ struct find_broadcast_transpose void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - auto ins_lens = ins->get_shape().lens(); + auto transpose = r.result; + auto transpose_lens = transpose->get_shape().lens(); auto bcast_ins = r.instructions["bcast_ins"]; auto input = bcast_ins->inputs().front(); // scalar transformation does not need extra transpose @@ -651,7 +651,8 @@ struct find_broadcast_transpose { // find common shape auto in_lens = input->get_shape().lens(); - int lens_diff = ins_lens.size() - in_lens.size(); + int lens_diff = transpose_lens.size() - in_lens.size(); + // insert unsqueeze if input lens < transpose lens if(lens_diff > 0) { std::vector unsqueeze_axes(lens_diff); @@ -659,11 +660,12 @@ struct find_broadcast_transpose input = m.insert_instruction( bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); } - input = m.insert_instruction(bcast_ins, ins->get_operator(), input); + // apply transpose before the multibroadcast + input = m.insert_instruction(bcast_ins, transpose->get_operator(), input); } auto new_mbcast = m.insert_instruction( - bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input); - m.replace_instruction(ins, new_mbcast); + bcast_ins, make_op("multibroadcast", {{"out_lens", transpose_lens}}), input); + m.replace_instruction(transpose, new_mbcast); } }; diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index cea2fd54a56..33cf2345799 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts) EXPECT(m1 == m2); } +TEST_CASE(simplify_inner_broadcast_no_common_axis) +{ + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {5, 10}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 5, 1}}); + auto xb = m1.add_instruction(b, x); + auto yb = m1.add_instruction(b, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb); + m1.add_instruction(pass_op{}, sum); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(simplify_add_conv1) { migraphx::module m; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index c16e9010ae7..f8914a5a438 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -67,6 +67,54 @@ migraphx::module make_concat_multibroadcast(const std::vector& in_lens, return m; } +TEST_CASE(broadcast_transpose) +{ + migraphx::module m1; + { + auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto mb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb); + m1.add_return({t1}); + } + run_pass(m1); + migraphx::module m2; + { + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); + auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1); + m2.add_return({mb}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(broadcast_transpose_opt) +{ + // extra transpose from transformation will be optimized out + migraphx::module m1; + { + auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto mb = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb); + m1.add_return({t1}); + } + run_pass(m1); + migraphx::module m2; + { + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1); + m2.add_return({mb}); + } + + EXPECT(m1 == m2); +} + TEST_CASE(broadcast_transpose_scalar) { migraphx::module m1; From f385f43d1a9e77e948c0ce1a7e6851bac61ff419 Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:41:11 -0500 Subject: [PATCH 7/8] formatting --- src/simplify_reshapes.cpp | 4 ++-- test/simplify_reshapes_test.cpp | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 9c32d661dd4..efb4d604f21 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -642,8 +642,8 @@ struct find_broadcast_transpose void apply(module& m, const match::matcher_result& r) const { - auto transpose = r.result; - auto transpose_lens = transpose->get_shape().lens(); + auto transpose = r.result; + auto transpose_lens = transpose->get_shape().lens(); auto bcast_ins = r.instructions["bcast_ins"]; auto input = bcast_ins->inputs().front(); // scalar transformation does not need extra transpose diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f8914a5a438..c656ea54bba 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -74,15 +74,17 @@ TEST_CASE(broadcast_transpose) auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); auto mb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); - auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb); + auto t1 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb); m1.add_return({t1}); } run_pass(m1); migraphx::module m2; { - auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); - auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1); + auto t1 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1); auto mb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1); m2.add_return({mb}); @@ -99,13 +101,14 @@ TEST_CASE(broadcast_transpose_opt) auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); auto mb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); - auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb); + auto t1 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb); m1.add_return({t1}); } run_pass(m1); migraphx::module m2; { - auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); + auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); auto mb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1); From ea62d7aa627d66c4b560c9323df16f3fa17c59bd Mon Sep 17 00:00:00 2001 From: Khalique Ahmed <15948690+kahmed10@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:02:16 -0500 Subject: [PATCH 8/8] update copyright and add description to struct --- src/simplify_algebra.cpp | 2 +- src/simplify_reshapes.cpp | 3 +++ test/simplify_algebra_test.cpp | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index e4e1c6276c7..301b1d9970f 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index efb4d604f21..9899e875cea 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary } }; +// simplifies broadcast->transpose to transpose->broadcast +// in the case of a scalar, simply rewrite to broadcast +// this can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp struct find_broadcast_transpose { auto matcher() const diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 33cf2345799..f5c64d2260f 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal