Skip to content

Commit

Permalink
Merge branch 'develop' into simplify_select_module
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Feb 16, 2024
2 parents 9f4b52e + b30a447 commit affd0b8
Show file tree
Hide file tree
Showing 4 changed files with 500 additions and 17 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* @causten
126 changes: 125 additions & 1 deletion src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
Expand Down Expand Up @@ -280,6 +281,78 @@ struct find_concat_multibroadcasts
}
};

struct find_concat_slice
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::outputs()](match::name("slice")));
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto outs = ins->outputs();
std::vector<migraphx::instruction_ref> slice_ins;
migraphx::transform_if(
outs.begin(),
outs.end(),
std::back_inserter(slice_ins),
[&](const auto& oins) { return oins->name() == "slice"; },
[&](const auto& oins) { return oins; });
int concat_axis = any_cast<op::concat>(ins->get_operator()).axis;
// prune slice candidates
std::vector<migraphx::instruction_ref> slice_candidates;
for(const auto& sins : range(slice_ins.begin(), slice_ins.end()))
{
auto sop = any_cast<op::slice>(sins->get_operator());
// slices with only one axis is allowed, because concat happens only one axis
if(sop.axes.size() != 1 or sop.axes.front() != concat_axis)
{
continue;
}
slice_candidates.push_back(sins);
}
if(slice_candidates.empty())
{
return;
}
std::vector<size_t> prefix_scan = {0};
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(prefix_scan), [&](const auto& i) {
return prefix_scan.back() + i->get_shape().lens()[concat_axis];
});
for(const auto& sins : slice_candidates)
{
auto sop = any_cast<op::slice>(sins->get_operator());
size_t slice_start = sop.starts.front();
size_t slice_len = sop.ends.front() - slice_start;
auto fii = std::find_if(prefix_scan.begin(), prefix_scan.end(), [&](const auto& j) {
return j == slice_start;
});
if(fii == prefix_scan.end())
{
continue;
}
// slice_len == 0
else if(fii == prefix_scan.end() - 1)
{
assert(slice_len == 0 or slice_start >= prefix_scan.back());
continue;
}
else
{
size_t idx = std::distance(prefix_scan.begin(), fii);
if(inputs[idx]->get_shape().lens()[concat_axis] == slice_len)
{
assert((prefix_scan[idx + 1] - prefix_scan[idx]) == slice_len);
m.replace_instruction(sins, inputs[idx]);
}
}
}
}
};

struct find_concat_transpose
{
auto matcher() const
Expand Down Expand Up @@ -806,6 +879,55 @@ struct find_transpose_slice
}
};

struct find_reshape_reshape_dot
{
auto matcher() const
{
return match::name("dot")(match::used_once(),
match::args(match::name("reshape").bind("inp_rsp1"),
match::name("reshape").bind("inp_rsp2")));
}

// Gemm axis should not be altered by the reshape
auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const
{
auto in_lens = in->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end());
}

// Batch dims should match for both inputs
auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();

return (
in1_lens.size() == in2_lens.size() and
std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2));
}

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto inp_rsp1 = r.instructions["inp_rsp1"];
auto inp_rsp2 = r.instructions["inp_rsp2"];

auto dot_lens = dot->get_shape().lens();

auto inp1 = inp_rsp1->inputs().front();
auto inp2 = inp_rsp2->inputs().front();

if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and
is_valid_inputs(inp1, inp2)))
return;

auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2);
m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot);
}
};

void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < depth; i++)
Expand All @@ -817,14 +939,16 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_concat_slice{},
find_concat_transpose{},
find_concat_multibroadcasts{},
find_nested_slice{},
find_nested_concat{},
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
find_transpose_contiguous_reshaper_unary{},
find_reshape_reshape_dot{});
dead_code_elimination{}.apply(m);
}
}
Expand Down
Loading

0 comments on commit affd0b8

Please sign in to comment.