
Commit

Revert "Revert "Handle broadcasts across dot and concat (#1689) (#1731)"" (#1837)
umangyadav authored Jun 14, 2023
1 parent 7b82fc3 commit 17c0317
Showing 3 changed files with 213 additions and 23 deletions.
19 changes: 11 additions & 8 deletions src/normalize_attributes.cpp
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
  *
  * See normalize_attribute.hpp for explaining the options.
  */
+template <class Message>
 auto tune_attribute(const std::vector<int64_t>& vec,
                     const std::vector<int64_t>& axes,
                     const value& val,
-                    const std::vector<std::size_t>& lens)
+                    const std::vector<std::size_t>& lens,
+                    Message m)
 {
     std::vector<int64_t> result(vec);
     int64_t n_rank = lens.size();
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
     else
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
 }
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
             if(not std::equal(
                    min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
         else
         {
             if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
     }
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
         const auto& key = rv.get_key();
         if(val.contains(key))
         {
-            auto vv = val.at(key).without_key();
+            auto message = [&] { return op.name() + ": " + key + ": "; };
+            auto vv      = val.at(key).without_key();
             if(vv.is_array())
             {
                 std::vector<int64_t> axes;
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
                     axes = val.at("axes").without_key().to_vector<int64_t>();
                 }
                 auto vec    = vv.to_vector<int64_t>();
-                auto result = tune_attribute(vec, axes, rv.without_key(), lens);
+                auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
                 val[key]    = result;
                 op.from_value(val);
                 val = op.to_value();
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             else
             {
                 auto num    = vv.to<int64_t>();
-                auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
+                auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
                 val[key]    = result.front();
                 op.from_value(val);
                 val = op.to_value();
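Net effect of the normalize_attributes.cpp changes: tune_attribute now receives a callable that builds an error-message prefix, so out-of-range diagnostics name the offending operator and attribute instead of the generic "TUNE_VECTOR:" tag. A minimal sketch of the resulting behavior; the "slice" op and "axes" key are illustrative stand-ins, not taken from the diff:

    // built once per attribute in normalize_attributes:
    auto message = [&] { return op.name() + ": " + key + ": "; };
    // an out-of-range "axes" value on a hypothetical "slice" op now throws
    //     slice: axes: value out of range!
    // where it previously threw
    //     TUNE_VECTOR: value out of range!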
131 changes: 116 additions & 15 deletions src/simplify_algebra.cpp
@@ -361,30 +361,118 @@ struct find_inner_broadcast
 {
     auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
 
+    static auto non_scalar_op(const std::string& name)
+    {
+        return [=](instruction_ref ins) {
+            if(ins->get_shape().scalar())
+                return false;
+            return ins->name() == name;
+        };
+    }
+
     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins        = r.result;
         auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
+            return;
+        bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
+                                any_of(broadcasts, non_scalar_op("multibroadcast"));
+        // If the broadcast is not a single dimension, then dont perform inner_broadcast
+        if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
+               if(i->get_shape().scalar())
+                   return false;
+               if(i->name() == "multibroadcast")
+                   return false;
+               auto input       = i->inputs().at(0);
+               const auto& lens = input->get_shape().lens();
+               return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
+                          return d == 1;
+                      }) < (lens.size() - 1);
+           }))
+            return;
         std::vector<instruction_ref> inputs;
         std::transform(broadcasts.begin(),
                        broadcasts.end(),
                        std::back_inserter(inputs),
-                       [](auto i) { return i->inputs().front(); });
-        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape() and
-                      i->get_shape().elements() != 1;
-           }))
-            return;
-
-        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
-            return not i->get_shape().scalar();
-        });
-        if(b_it == broadcasts.end())
-            b_it = broadcasts.begin();
+                       [&](instruction_ref i) {
+                           auto input = i->inputs().front();
+                           if(mixed_broadcasts and not i->get_shape().scalar() and
+                              i->get_shape().lens().size() > 1)
+                               return m.insert_instruction(i, make_op("squeeze"), input);
+                           return input;
+                       });
+
+        std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
+                      if(i->get_shape().scalar())
+                          return 2;
+                      else if(i->name() == "broadcast")
+                          return 0;
+                      if(i->name() == "multibroadcast")
+                          return 1;
+                      return 3;
+                  }));
         auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, (*b_it)->get_operator(), op);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
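find_inner_broadcast previously required all inputs to share one broadcast; it now also handles a mix of broadcast and multibroadcast inputs, provided each non-scalar broadcast input has at most one non-unit dimension. In that case it squeezes the inputs, applies the pointwise op on the small shapes, and re-broadcasts the single result; the sort makes a broadcast op, when present, supply the operator for the final re-broadcast. A sketch of the rewrite, with shapes borrowed from the simplify_inner_broadcast_different_broadcasts test added below:

    x:{24}     --broadcast(axis=1)--> {1,24,112,112}  \
                                                       add --> {1,24,112,112}
    y:{24,1,1} --multibroadcast----> {1,24,112,112}   /
    // becomes
    add(squeeze(x), squeeze(y)):{24} --broadcast(axis=1)--> {1,24,112,112}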

+struct find_dot_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto a   = ins->inputs()[0];
+        auto b   = ins->inputs()[1];
+        if(a->get_operator().name() != b->get_operator().name())
+            return;
+        if(ins->get_shape().lens().size() < 3)
+            return;
+        auto nbatch_axes      = ins->get_shape().lens().size() - 2;
+        const auto& a_strides = a->get_shape().strides();
+        const auto& b_strides = b->get_shape().strides();
+        // Find leading batch axes that are broadcasted
+        auto p =
+            std::mismatch(a_strides.begin(),
+                          a_strides.begin() + nbatch_axes,
+                          b_strides.begin(),
+                          b_strides.begin() + nbatch_axes,
+                          [](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
+        auto naxes = p.first - a_strides.begin();
+        assert(naxes <= nbatch_axes);
+        std::vector<std::size_t> axes(naxes);
+        std::iota(axes.begin(), axes.end(), 0);
+
+        auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
+            auto input = b_ins->inputs()[0];
+            std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
+                                          b_ins->get_shape().lens().end());
+            if(b_ins->name() == "multibroadcast")
+            {
+                return m.insert_instruction(
+                    ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
+            }
+            else if(b_ins->name() == "broadcast")
+            {
+                auto v    = b_ins->get_operator().to_value();
+                auto axis = v.at("axis").to<std::size_t>() - naxes;
+                return m.insert_instruction(
+                    ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
+            }
+            assert(false);
+            return m.end();
+        };
+        auto a1        = insert_broadcast(a);
+        auto b1        = insert_broadcast(b);
+        auto dot       = m.insert_instruction(ins, make_op("dot"), a1, b1);
+        auto broadcast = m.insert_instruction(
+            ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
+        m.replace_instruction(ins, broadcast);
+    }
+};
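The new find_dot_broadcast matches a dot whose operands are both broadcast along the same leading batch axes (stride 0 in both operands). It drops those axes, performs the dot once on the reduced shapes, and multibroadcasts the product back to the full output shape, so the matrix multiply is not repeated per identical batch. A sketch using the shapes from the dot_broadcast_different_rank test added below:

    dot(multibroadcast(x:{768})      --> {2,384,768},
        multibroadcast(y:{768,3072}) --> {2,768,3072})    // {2,384,3072}
    // becomes
    multibroadcast(dot(multibroadcast(x) --> {384,768},
                       multibroadcast(y) --> {768,3072})) // one {384,3072} dot,
                                                          // broadcast to {2,384,3072}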

Expand All @@ -393,7 +481,8 @@ struct find_concat_op
     auto matcher() const
     {
         return match::name("concat")(match::any_of[match::inputs()](
-            match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
+            match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
+            match::used_once()));
     }
 
     template <class Iterator>
Expand All @@ -412,7 +501,8 @@ struct find_concat_op

     static bool is_valid_op(const operation& op)
     {
-        return op.name() == "broadcast" or op.attributes().contains("pointwise");
+        return contains({"broadcast", "multibroadcast"}, op.name()) or
+               op.attributes().contains("pointwise");
     }
 
     void apply(module& m, const match::matcher_result& r) const
@@ -440,6 +530,16 @@ struct find_concat_op
             op    = b;
             iaxis = 0;
         }
+        else if(op.name() == "multibroadcast")
+        {
+            shape bshape = (*start)->get_shape();
+            auto input   = (*start)->inputs()[0];
+            if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
+                return {start, last};
+            op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
+            auto delta = bshape.lens().size() - input->get_shape().lens().size();
+            iaxis -= delta;
+        }
 
         std::vector<instruction_ref> concats;
         for(std::size_t i = 0; i < x->inputs().size(); i++)
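find_concat_op's matcher and is_valid_op now accept multibroadcast as well, and the new branch above hoists a concat over multibroadcast inputs only when the concat axis is a real, non-broadcast axis of the broadcast output (strides()[iaxis] != 0), shifting the axis down by the rank the broadcast added. A hypothetical sketch; the shapes are illustrative, not from the diff:

    // a:{384,768} and b:{384,768}, each multibroadcast to {2,384,768}, concat on axis 2
    concat(multibroadcast(a), multibroadcast(b), axis=2)  // axis 2 has stride 1 != 0
    // becomes (delta = 3 - 2 = 1, so axis 2 -> 1)
    multibroadcast(concat(a, b, axis=1))                  // {384,1536} -> {2,384,1536}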
@@ -1260,6 +1360,7 @@ void simplify_algebra::apply(module& m) const
 {
     match::find_matches(m,
                         find_inner_broadcast{},
+                        find_dot_broadcast{},
                         find_double_add_lit_broadcast{},
                         find_add_lit_broadcast{},
                         find_add_convs{},
86 changes: 86 additions & 0 deletions test/simplify_algebra_test.cpp
@@ -613,6 +613,60 @@ TEST_CASE(simplify_inner_broadcast_scalar)
     EXPECT(m1 == m2);
 }
 
+TEST_CASE(simplify_inner_broadcast_different_dims)
+{
+    auto b = migraphx::op::multibroadcast{{2, 384, 768}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {768}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(b, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
+        auto yb   = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), x, yb);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(simplify_inner_broadcast_different_broadcasts)
+{
+    auto b  = migraphx::op::broadcast{1, {1, 24, 112, 112}};
+    auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}};
+    migraphx::module m1;
+    {
+        auto x   = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
+        auto y   = m1.add_parameter("y", {migraphx::shape::int32_type, {24, 1, 1}});
+        auto xb  = m1.add_instruction(b, x);
+        auto yb  = m1.add_instruction(mb, y);
+        auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
+        m1.add_instruction(pass_op{}, sum);
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto x    = m2.add_parameter("x", {migraphx::shape::int32_type, {24}});
+        auto y    = m2.add_parameter("y", {migraphx::shape::int32_type, {24, 1, 1}});
+        auto xs   = m2.add_instruction(migraphx::make_op("squeeze"), x);
+        auto ys   = m2.add_instruction(migraphx::make_op("squeeze"), y);
+        auto sum  = m2.add_instruction(migraphx::make_op("add"), xs, ys);
+        auto sumb = m2.add_instruction(b, sum);
+        m2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(simplify_add_conv1)
 {
     migraphx::module m;
@@ -3003,6 +3057,38 @@ TEST_CASE(reorder_slice_ins_deps)
     EXPECT(m == create_module());
 }
 
+TEST_CASE(dot_broadcast_different_rank)
+{
+    migraphx::module m1;
+    {
+        auto x  = m1.add_parameter("x", {migraphx::shape::float_type, {768}});
+        auto y  = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
+        auto xb = m1.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), x);
+        auto yb = m1.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y);
+        auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb);
+        m1.add_return({dot});
+    };
+
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {768}});
+        auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
+        auto xb =
+            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x);
+        auto yb =
+            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y);
+        auto dot       = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
+        auto broadcast = m2.add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot);
+        m2.add_return({broadcast});
+    };
+
+    run_pass(m1);
+    EXPECT(m1.sort() == m2.sort());
+}
+
 TEST_CASE(dot_fusion_reshape)
 {
     migraphx::module m1;
