Skip to content

Commit

Permalink
Merge branch 'develop' into fix_remove_rocblas_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kahmed10 authored May 21, 2024
2 parents dfb7450 + 1f07af9 commit c97572a
Show file tree
Hide file tree
Showing 16 changed files with 470 additions and 83 deletions.
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
rocm-docs-core==1.1.1
rocm-docs-core==1.1.2
sphinx-collapse
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
rocm-docs-core==1.1.1
rocm-docs-core==1.1.2
# via -r requirements.in
smmap==5.0.0
# via gitdb
Expand Down
32 changes: 22 additions & 10 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>
Expand Down Expand Up @@ -67,7 +67,7 @@ struct fused_reduce
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
return shapes.at(name).lens() == input.lens();
}))
MIGRAPHX_THROW("Dimenstion does not match the submodule.");
MIGRAPHX_THROW("Input dimension does not match the submodule.");

return shape::from_permutation(sm->get_output_shapes().front().type(),
sm->get_output_shapes().front().lens(),
Expand All @@ -78,6 +78,17 @@ struct fused_reduce
};
MIGRAPHX_REGISTER_OP(fused_reduce);

/*
* Predicate matcher checks that input and output shapes have the same rank. This is assumed
* for broadcast instructions for these fusions.
*/
MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
{
    // Compare the rank of the first input's shape against the rank of the
    // instruction's own output shape; true only when they agree.
    const auto& in_shape  = ins->inputs().front()->get_shape();
    const auto& out_shape = ins->get_shape();
    return in_shape.ndim() == out_shape.ndim();
}

static void insert_params(module_ref sm,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
Expand Down Expand Up @@ -227,7 +238,8 @@ template <class... Ms>
static auto match_broadcast(Ms... ms)
{
return match::skip(match::name("contiguous"))(
match::name("multibroadcast")(match::arg(0)(ms...), match::used_once())
match::name("multibroadcast")(
match::arg(0)(ms...), match::used_once(), input_output_ndim_match())
.bind("broadcast"))
.bind("final_broadcast");
}
Expand Down Expand Up @@ -257,19 +269,19 @@ struct find_pointwise_reduce
{
auto matcher() const
{
// fused_reduce instruction with pointwise inputs.
return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce = r.result;
auto input = r.instructions["pointwise"];

auto input = r.instructions["pointwise"];
const auto* pm = input->module_inputs().front();
const auto* old_rm = reduce->module_inputs().front();

auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name());
rm->set_bypass();

std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Insert pointwise
auto rins = insert_ins_in_submodule(rm, input, map_ins).front();
Expand Down
18 changes: 18 additions & 0 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,27 @@ struct matcher_result
});
}

// Debugging aid: print each bound name followed by a dump of the
// instruction it was bound to during matching.
void debug_print() const
{
    for(const auto& [bound_name, bound_ins] : ins_map)
    {
        std::cout << bound_name << ": \n";
        bound_ins->debug_print();
    }
}

private:
std::unordered_map<std::string, instruction_ref> ins_map;
};

// Debugging aid: dump all bound instructions followed by the root result of
// the match.
// NOTE(review): the printed label reads "matcher_container" but the enclosing
// struct appears to be matcher_result (per the surrounding hunk) — confirm
// the intended label.
void debug_print() const
{
    std::cout << "matcher_container: \n instructions:";
    instructions.debug_print();
    std::cout << " result: \n";
    result->debug_print();
}

instruction_container instructions;
instruction_ref result;
};
Expand Down
39 changes: 29 additions & 10 deletions src/include/migraphx/rewrite_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,16 @@ struct rewrite_reshapes
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous = [](auto... ms) {
return match::arg(0)(match::skip(
match::name("contiguous", "multibroadcast")(match::used_once()))(ms...));
auto skip_contiguous_broadcast =
match::skip(match::name("contiguous", "multibroadcast")(match::used_once()));
auto skip_contiguous_broadcast_arg = [&](auto... ms) {
return match::arg(0)(skip_contiguous_broadcast(ms...));
};
auto pointwise = match::name(op1)(match::used_once());
auto reshape_pointwise = reshape(skip_contiguous(pointwise.bind("x"))).bind("reshape");
return match::name(op2)(match::any_of[match::inputs()](reshape_pointwise));
auto reshape_pointwise =
reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape");
return match::name(op2)(match::any_of[match::inputs()](
skip_contiguous_broadcast(reshape_pointwise).bind("input")));
}

template <class F>
Expand All @@ -107,17 +110,33 @@ struct rewrite_reshapes
return x_ins == input;
}

/// Searches the inputs between `start` and `last` for a "multibroadcast"
/// instruction (presumably via find_input_if's input traversal — semantics
/// defined elsewhere in this file).
/// Returns false when no broadcast is found, true when one is found and it
/// feeds from `last` (per match_input), and nullopt when a broadcast exists
/// but does not line up with `last` (caller should bail out).
static std::optional<bool> is_broadcasted(instruction_ref start, instruction_ref last)
{
    const auto is_multibroadcast = [&](auto i) { return i->name() == "multibroadcast"; };
    auto found_ins               = find_input_if(start, last, is_multibroadcast);
    if(found_ins == last)
        return false;
    if(not match_input(found_ins, last))
        return nullopt;
    return true;
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto input_ins = r.instructions["input"];

auto broadcast_ins = find_input_if(
reshape_ins, x_ins, [&](auto i) { return i->name() == "multibroadcast"; });
const bool has_broadcast = broadcast_ins != x_ins;
if(has_broadcast and not match_input(broadcast_ins, x_ins))
const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins);
const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins);
if(not has_broadcast_before_reshape.has_value())
return;
if(not has_broadcast_after_reshape.has_value())
return;
if(*has_broadcast_after_reshape and *has_broadcast_before_reshape)
return;
const bool has_broadcast =
*has_broadcast_after_reshape or *has_broadcast_before_reshape;

auto dims1 = T::base_dims(ins);
auto dims2 = T::base_dims(x_ins);
Expand Down Expand Up @@ -153,7 +172,7 @@ struct rewrite_reshapes

auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == reshape_ins)
if(input == input_ins)
return new_x_ins;
return reshape_input(ins)(input);
});
Expand Down
11 changes: 10 additions & 1 deletion src/include/migraphx/tf.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -45,6 +45,15 @@ struct tf_options
MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name,
const tf_options& options = tf_options{});

/// Create a program from a tf buffer
MIGRAPHX_TF_EXPORT program parse_tf_buffer(const std::string& buffer,
const tf_options& options = tf_options{});

/// Create a program from a tf buffer given a raw pointer and size
MIGRAPHX_TF_EXPORT program parse_tf_buffer(const void* data,
std::size_t size,
const tf_options& options = tf_options{});

MIGRAPHX_TF_EXPORT std::vector<std::string> get_tf_operators();

} // namespace MIGRAPHX_INLINE_NS
Expand Down
2 changes: 2 additions & 0 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,12 @@ const module* program::get_module(const std::string& name) const { return &impl-

// Create a new, empty module registered under `name` and owned by this
// program; returns a pointer to it. The name must be unique within the
// program — a duplicate trips the assert in debug builds.
module* program::create_module(const std::string& name)
{
    // (removed a stray blank line introduced by the merge)
    assert(not contains(impl->modules, name));
    auto r = impl->modules.emplace(name, name);
    return &(r.first->second);
}

module* program::create_module(const std::string& name, module m)
{
assert(not contains(impl->modules, name));
Expand Down
73 changes: 59 additions & 14 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ struct find_nested_slice
}
};

/**
* Example case
* From:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* mb0: multibroadcast(param0, output_lens = [2, 3, 4])
* mb1: multibroadcast(param1, output_lens = [2, 3, 4])
* concat(mb0, mb1, axis = 2)
*
* To:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* con0: concat(param0, param1, axis = 1)
* multibroadcast(con0, lens = [2, 3, 4])
*/
struct find_concat_multibroadcasts
{
auto matcher() const
Expand All @@ -253,32 +268,62 @@ struct find_concat_multibroadcasts

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
auto concat_ins = mr.result;
auto concat_op = any_cast<op::concat>(concat_ins->get_operator());
auto concat_out_lens = concat_ins->get_shape().lens();
auto concat_inputs = concat_ins->inputs();
auto front_mb_strides = concat_inputs.front()->get_shape().strides();
assert(concat_op.axis >= 0);

// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) {
return i->get_shape().strides()[concat_op.axis] == 0;
}))
{
return;
}

// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
// Get the inputs of multibroadcast ops. Will be used as inputs to new concat op
std::vector<instruction_ref> mb_inputs(concat_inputs.size());
std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) {
return i->inputs().front();
});

// Check that the inputs into the multibroadcasts have the same rank
const auto& first_shape = mb_inputs.front()->get_shape();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) {
return mb_in->get_shape().ndim() == first_shape.ndim();
}))
{
return;
}

// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size())
{
concat_op.axis -=
std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0);
}

auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
// Inputs to multibroadcasts should have the same dimensions except for the axis to
// concatenate over
const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) {
const auto& lens = input_to_mb->get_shape().lens();
return std::equal(
lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and
std::equal(lens.begin() + concat_op.axis + 1,
lens.end(),
front_in_lens.begin() + concat_op.axis + 1);
}))
{
return;
}

auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs);
m.replace_instruction(concat_ins,
migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}),
new_concat_ins);
}
};

Expand Down
Loading

0 comments on commit c97572a

Please sign in to comment.