Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into fuse_contiguous_pointwise
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav committed Mar 15, 2024
2 parents 4b43c35 + ef285c9 commit a739662
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 32 deletions.
6 changes: 3 additions & 3 deletions src/onnx/parse_groupnorm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -118,8 +118,8 @@ struct parse_groupnorm : op_parser<parse_groupnorm>
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
return info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
}
};
Expand Down
65 changes: 62 additions & 3 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

template <class... Ms>
auto skip_post_dq_ops(Ms... ms)
Expand Down Expand Up @@ -298,6 +299,64 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
return (x == y) or diff_shapes_equal_vals;
}

// Returns true when iterator `y` is found by walking forward from the element
// right after `x` up to (but not including) `last`, i.e. when `y` lies in the
// open range (x, last).
//
// `x == y` yields false because the scan starts one past `x`, and `y == last`
// yields false because `last` itself is never visited. If `x` is already
// `last` there is nothing after it, so the function returns false instead of
// advancing a past-the-end iterator.
template <class Iterator>
bool precedes(Iterator x, Iterator y, Iterator last)
{
    // Guard: std::next(last) would step beyond the end of the sequence.
    if(x == last)
        return false;
    for(auto it = std::next(x); it != last; ++it)
    {
        // Positions are compared, not the pointed-to values.
        if(it == y)
            return true;
    }
    return false;
}

// Rewrites a graph where the input `x` of a single-use quantizelinear is also
// read by one other instruction: that other reader is redirected to a new
// dequantizelinear of the quantized value instead of reading `x` directly.
struct match_qlinear_reused
{
    // Matches a quantizelinear that is used exactly once and whose first
    // argument is NOT used exactly once; that argument is bound as "x".
    auto matcher() const
    {
        auto shared_input = match::none_of(match::used_once()).bind("x");
        return match::name("quantizelinear")(match::used_once(), match::arg(0)(shared_input));
    }

    // `r.result` is the matched quantizelinear, `r.instructions["x"]` its
    // first argument. Nothing is changed unless `x` has exactly two outputs.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        assert(ins != x_ins);

        // Work on a copy of the output list.
        // NOTE(review): presumably replace_argument below mutates x_ins's
        // outputs, which is why a snapshot is iterated - confirm.
        const auto consumers = x_ins->outputs();
        if(consumers.size() != 2)
            return;

        // Same arguments as the quantizelinear, except that the first one is
        // the quantizelinear itself.
        auto dq_args    = ins->inputs();
        dq_args.front() = ins;

        for(auto consumer : consumers)
        {
            // Only rewrite a non-quantizelinear consumer whose operator has a
            // "pointwise" attribute and that appears after `ins` in the module.
            const bool rewritable =
                consumer->name() != "quantizelinear" and
                consumer->get_operator().attributes().contains("pointwise") and
                precedes(ins, consumer, m.end());
            if(not rewritable)
                continue;

            auto dq =
                m.insert_instruction(std::next(ins), make_op("dequantizelinear"), dq_args);
            instruction::replace_argument(consumer, x_ins, dq);
        }
    }
};

// Two instructions carry the same value when they are the very same
// instruction, or when compare_literals reports them as equal.
bool is_same_value(instruction_ref a, instruction_ref b)
{
    return (a == b) or compare_literals(a, b);
}

// Compares input 1 of `a` and `b` and, when present, input 2 as well.
// NOTE(review): presumably input 1 is the scale and input 2 the zero point of
// a quantize/dequantize instruction - confirm against callers.
//
// Returns false when the two instructions have a different number of inputs.
// With exactly two inputs only input 1 is compared.
bool is_same_scale_zero(instruction_ref a, instruction_ref b)
{
    const auto& a_args = a->inputs();
    const auto& b_args = b->inputs();
    if(a_args.size() != b_args.size())
        return false;

    const bool same_scale = is_same_value(a_args.at(1), b_args.at(1));
    if(a_args.size() == 2)
        return same_scale;

    return same_scale and is_same_value(a_args.at(2), b_args.at(2));
}

void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
Expand All @@ -308,23 +367,23 @@ void remove_qdq_pairs(module& m)
if(arg->name() == "dequantizelinear")
{
auto q = arg->inputs().front();
if((q->name() == "quantizelinear") and
compare_literals(arg->inputs().at(1), q->inputs().at(1)) and
compare_literals(arg->inputs().at(2), q->inputs().at(2)))
if((q->name() == "quantizelinear") and is_same_scale_zero(arg, q))
{
instruction::replace_argument(ins, arg, q->inputs().front());
}
}
}
}
}
} // namespace

// Runs the pass in five ordered steps: match quantizable ops, dead code
// elimination, remove_qdq_pairs, dead code elimination again, and finally the
// match_qlinear_reused rewrite.
void simplify_qdq::apply(module& m) const
{
    // Dead code elimination is needed after each of the first two rewrites.
    const auto eliminate_dead_code = [&] {
        migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
    };

    match::find_matches(m, match_find_quantizable_ops{});
    eliminate_dead_code();

    remove_qdq_pairs(m);
    eliminate_dead_code();

    match::find_matches(m, match_qlinear_reused{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
35 changes: 26 additions & 9 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/algorithm.hpp>
#include <optional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -149,7 +150,7 @@ struct mlir_op
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");

auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
auto mod_params = mod->get_parameter_names();
std::sort(mod_params.begin(), mod_params.end());
std::unordered_map<instruction_ref, shape> mod_ins_shapes;
Expand Down Expand Up @@ -524,15 +525,17 @@ struct find_mlir_standalone_attention_op
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();

std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto orig_inputs = gemm_softmax_gemm->inputs();

std::vector<instruction_ref> gemm0_inputs = {orig_inputs[0], orig_inputs[1]};
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));

std::vector<instruction_ref> inputs;
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());

// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
Expand All @@ -542,18 +545,31 @@ struct find_mlir_standalone_attention_op
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);

std::optional<instruction_ref> bias{nullopt};
if(orig_inputs.size() == 4)
{
auto bias_input = orig_inputs[2];
instruction_ref bias_param =
mm->add_parameter("y_bias", bias_input->get_shape().as_standard());
bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param);
inputs.push_back(bias_input);
}

auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}),
bias ? bias.value() : scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] = get_fusable_input_op_stream(orig_inputs.back());
instruction_ref new_upper_v =
mm->add_parameter("z", old_upper_v->get_shape().as_standard());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

std::unordered_map<instruction_ref, instruction_ref> ins_map;
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
Expand All @@ -567,6 +583,7 @@ struct find_mlir_standalone_attention_op
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});

mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
Expand Down
27 changes: 22 additions & 5 deletions src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@

#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <sstream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -55,14 +56,30 @@ struct gemm_softmax_gemm
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];

const bool is_bias_enabled = inputs.size() == 4;
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[is_bias_enabled ? 3 : 2];

for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
auto gemm0_shape = op.compute_shape({a, b});
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
if(bias_shape.lens() != gemm0_shape.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent bias size"
<< ". Expected: " << gemm0_shape << ". Given: " << bias_shape;
MIGRAPHX_THROW(err_msg.str());
}
}

return op.compute_shape({gemm0_shape, b1});
}

static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
Expand Down
23 changes: 21 additions & 2 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
Expand Down Expand Up @@ -171,6 +172,17 @@ auto is_test_gemm(bool enable_attention)
});
}

// Builds a predicate matcher that ignores the instruction it is given.
// In builds with MIGRAPHX_USE_COMPOSABLEKERNEL it succeeds only while
// MIGRAPHX_ENABLE_CK is not enabled; in all other builds it always succeeds.
auto is_bias_supported()
{
    const auto bias_allowed = [](instruction_ref) -> bool {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
        const bool ck_enabled = enabled(MIGRAPHX_ENABLE_CK{});
        return not ck_enabled;
#else
        return true;
#endif
    };
    return match::make_basic_pred_matcher(bias_allowed);
}

struct find_gemm_softmax_gemm
{
bool enable_attention = false;
Expand All @@ -182,8 +194,11 @@ struct find_gemm_softmax_gemm
.bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto add = match::name("add")(is_bias_supported(),
match::nargs(2),
match::any_arg(0, 1)(match::none_of(mul).bind("bias")));
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax");
match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1))).bind("softmax");

return match::name("dot")(
match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention))
Expand All @@ -210,7 +225,11 @@ struct find_gemm_softmax_gemm
});
}

auto inputs = gemm1_ins->inputs(); // A, B
auto inputs = gemm1_ins->inputs(); // A, B
if(contains(r.instructions, "bias"))
{
inputs.push_back(r.instructions["bias"]);
}
inputs.push_back(gemm2_ins->inputs().back()); // B1

mpm.get_module().replace_instruction(
Expand Down
68 changes: 68 additions & 0 deletions test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp>
Expand All @@ -39,6 +40,10 @@ bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "co
// True when the instruction's name is exactly "dot".
bool is_dot(const migraphx::instruction& ins)
{
    const auto op_name = ins.name();
    return op_name == "dot";
}

// Applies the simplify_qdq pass to `m`.
void run_pass(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::simplify_qdq{}});
}
// Applies eliminate_common_subexpression followed by dead_code_elimination
// to `m`.
void run_cse(migraphx::module& m)
{
    migraphx::run_passes(m,
                         {migraphx::eliminate_common_subexpression{},
                          migraphx::dead_code_elimination{}});
}

migraphx::instruction_ref broadcast_scale(migraphx::module& m,
migraphx::instruction_ref scale,
Expand Down Expand Up @@ -1419,6 +1424,69 @@ TEST_CASE(dot_correctness)
EXPECT(migraphx::verify::verify_rms_range(rv1, rv2));
}

// Checks simplify_qdq on two chained dots where the intermediate result
// (add1) is consumed twice: once through a quantize/dequantize pair feeding
// the second dot, and once directly by the final add.
TEST_CASE(dot_reused)
{
migraphx::shape sh{migraphx::shape::float_type, {256, 256}};

// m1: input graph. Every dot operand goes through a quantizelinear ->
// dequantizelinear pair that shares the same scale and zero point.
migraphx::module m1;
{
auto x = m1.add_parameter("x", sh);
auto y = m1.add_parameter("y", sh);
auto w1 = m1.add_parameter("w1", sh);
auto w2 = m1.add_parameter("w2", sh);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});

auto q1 = add_quantize_op(m1, "quantizelinear", x, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", w1, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
auto add1 = m1.add_instruction(migraphx::make_op("add"), dot1, y);

// add1 is read by q3 here and again, unquantized, by add2 below.
auto q3 = add_quantize_op(m1, "quantizelinear", add1, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto q4 = add_quantize_op(m1, "quantizelinear", w2, scale, zero);
auto d4 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), d3, d4);
auto add2 = m1.add_instruction(migraphx::make_op("add"), dot2, add1);
m1.add_return({add2});
}

// m2: expected graph. Each dot becomes a quant_dot on the quantized operands,
// followed by a dequantizelinear that uses the scale built by add_scale_mul
// and an int32 zero point.
migraphx::module m2;
{
auto x = m2.add_parameter("x", sh);
auto y = m2.add_parameter("y", sh);
auto w1 = m2.add_parameter("w1", sh);
auto w2 = m2.add_parameter("w2", sh);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero2 = m2.add_literal(std::int32_t{0});

auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero);

auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2);
auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y);

auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero);
auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero);
auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4);
auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2);
// The final add no longer reads add1 directly: it reads a dequantizelinear
// of q3 built from q3's own scale and zero-point inputs.
auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], q3->inputs()[2]);
auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3);
m2.add_return({add2});
}

// Only m1 is run through simplify_qdq; both modules then get CSE + DCE so
// duplicated instructions collapse before the sorted modules are compared.
run_pass(m1);
run_cse(m1);
run_cse(m2);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_asymmetric_correctness)
{
migraphx::shape sh1{migraphx::shape::float_type, {10, 4}};
Expand Down
Loading

0 comments on commit a739662

Please sign in to comment.