Commit 6dc8ef1

Merge branch 'develop' into fix_parse_dynamicquantizelinear

TedThemistokleous authored Mar 18, 2024
2 parents fc2c4e8 + 21b71c6

Showing 8 changed files with 249 additions and 31 deletions.
35 changes: 26 additions & 9 deletions src/targets/gpu/fuse_mlir.cpp
@@ -29,6 +29,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/algorithm.hpp>
#include <optional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -149,7 +150,7 @@ struct mlir_op
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");

auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
auto mod_params = mod->get_parameter_names();
std::sort(mod_params.begin(), mod_params.end());
std::unordered_map<instruction_ref, shape> mod_ins_shapes;
@@ -524,15 +525,17 @@ struct find_mlir_standalone_attention_op
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();

std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto orig_inputs = gemm_softmax_gemm->inputs();

std::vector<instruction_ref> gemm0_inputs = {orig_inputs[0], orig_inputs[1]};
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));

std::vector<instruction_ref> inputs;
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());

// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
@@ -542,18 +545,31 @@ struct find_mlir_standalone_attention_op
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);

std::optional<instruction_ref> bias{nullopt};
if(orig_inputs.size() == 4)
{
auto bias_input = orig_inputs[2];
instruction_ref bias_param =
mm->add_parameter("y_bias", bias_input->get_shape().as_standard());
bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param);
inputs.push_back(bias_input);
}

auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}),
bias ? bias.value() : scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] = get_fusable_input_op_stream(orig_inputs.back());
instruction_ref new_upper_v =
mm->add_parameter("z", old_upper_v->get_shape().as_standard());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});

std::unordered_map<instruction_ref, instruction_ref> ins_map;
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
@@ -567,6 +583,7 @@ struct find_mlir_standalone_attention_op
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});

mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
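The fuse_mlir.cpp change above threads an optional bias through the fused attention module: when gemm_softmax_gemm carries a fourth input, a "y_bias" parameter plus an extra add are inserted between the scale and the softmax, and the bias is appended to the fused module's inputs. Below is a minimal standalone sketch of that conditional-stage pattern using plain C++ containers rather than the MIGraphX module/instruction API; every name in it is illustrative.

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-in for an instruction in a fused module.
struct stage { std::string op; };

int main()
{
    // Four original inputs means the matched pattern carried a bias term.
    std::vector<std::string> orig_inputs = {"A", "B", "bias", "V"};
    std::vector<stage> fused             = {{"dot"}, {"mul"}}; // gemm0, scale

    std::optional<std::size_t> bias_stage; // set only when a bias input exists
    if(orig_inputs.size() == 4)
    {
        fused.push_back({"add"}); // scaled_gemm0 + y_bias
        bias_stage = fused.size() - 1;
    }

    // The softmax reads the bias stage when present, otherwise the scaled gemm0.
    std::size_t softmax_input = bias_stage.value_or(fused.size() - 1);
    fused.push_back({"softmax"});
    fused.push_back({"dot"}); // gemm1

    std::cout << "softmax consumes '" << fused[softmax_input].op << "' out of "
              << fused.size() << " fused stages\n";
}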
12 changes: 8 additions & 4 deletions src/targets/gpu/fuse_ops.cpp
@@ -767,11 +767,15 @@ struct find_contiguous
}
};

struct find_pointwise_contiguous
struct find_pointwise_layout_contiguous
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
auto is_layout = precompile_name("layout")(
match::arg(0)(match::used_once(), precompile_name("pointwise")));
auto is_contiguous = match::name("gpu::contiguous")(
match::arg(0)(match::used_once(), precompile_name("pointwise")));
return match::any_of(is_layout, is_contiguous);
}

void apply(module& m, const match::matcher_result& r) const
@@ -782,7 +786,7 @@ struct find_pointwise_contiguous
auto args = pw->inputs();
args.back() = alloc;

// Ensure the output shape of the pointwise module is contiguous
// Ensure the output shape of the pointwise module retains the memory layout
auto pw_op_val = pw->get_operator().to_value();
pw_op_val["output_shape"] = to_value(ins->get_shape());

@@ -852,7 +856,7 @@ struct find_concat_pointwise

void fuse_ops::apply(module& m) const
{
match::find_matches(m, find_pointwise_contiguous{});
match::find_matches(m, find_pointwise_layout_contiguous{});
run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
run_passes(m, {dead_code_elimination{}});
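find_pointwise_layout_contiguous widens the old find_pointwise_contiguous matcher so that the copy absorbed into the pointwise kernel can be either a precompiled layout or a gpu::contiguous, in both cases fed by a single-use precompiled pointwise module. The sketch below mimics that "same predicate, two op names, any_of" composition with plain std::function predicates instead of the MIGraphX matcher DSL; the node layout and names are invented for the example.

#include <functional>
#include <iostream>
#include <string>

// Toy instruction node: an op name, its first argument, and that argument's use count.
struct ins
{
    std::string name;
    const ins* arg0 = nullptr;
    int use_count   = 1;
};

using pred = std::function<bool(const ins&)>;

pred any_of(pred a, pred b)
{
    return [=](const ins& i) { return a(i) or b(i); };
}

pred copy_from_pointwise(const std::string& op_name)
{
    return [=](const ins& i) {
        return i.name == op_name and i.arg0 != nullptr and i.arg0->use_count == 1 and
               i.arg0->name == "gpu::precompile_op[pointwise]";
    };
}

int main()
{
    auto matcher = any_of(copy_from_pointwise("layout"), copy_from_pointwise("gpu::contiguous"));

    ins pw{"gpu::precompile_op[pointwise]"};
    ins layout_ins{"layout", &pw};
    ins contig_ins{"gpu::contiguous", &pw};
    ins other_ins{"transpose", &pw};

    std::cout << matcher(layout_ins) << matcher(contig_ins) << matcher(other_ins) << "\n"; // 110
}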
27 changes: 22 additions & 5 deletions src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -26,6 +26,7 @@

#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <sstream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -55,14 +56,30 @@ struct gemm_softmax_gemm
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];

const bool is_bias_enabled = inputs.size() == 4;
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[is_bias_enabled ? 3 : 2];

for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
auto gemm0_shape = op.compute_shape({a, b});
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
if(bias_shape.lens() != gemm0_shape.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent bias size"
<< ". Expected: " << gemm0_shape << ". Given: " << bias_shape;
MIGRAPHX_THROW(err_msg.str());
}
}

return op.compute_shape({gemm0_shape, b1});
}

static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
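In gemm_softmax_gemm.hpp the shape computation now recognises the optional bias: with four inputs, inputs[2] is the bias and must match the lengths of the first GEMM's output, and the second GEMM then consumes the last input. The snippet below approximates that check with plain vectors of dimensions; dot_lens and attention_lens are invented helpers, not MIGraphX functions.

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

using lens_t = std::vector<std::size_t>;

// {.., M, K} x {.., K, N} -> {.., M, N}
lens_t dot_lens(const lens_t& a, const lens_t& b)
{
    lens_t out = a;
    out.back() = b.back();
    return out;
}

lens_t attention_lens(const std::vector<lens_t>& inputs)
{
    const bool is_bias_enabled = inputs.size() == 4;
    auto gemm0 = dot_lens(inputs[0], inputs[1]);
    if(is_bias_enabled and inputs[2] != gemm0)
    {
        std::stringstream err_msg;
        err_msg << "gemm_softmax_gemm: has inconsistent bias size";
        throw std::runtime_error(err_msg.str());
    }
    const auto& b1 = is_bias_enabled ? inputs[3] : inputs[2];
    return dot_lens(gemm0, b1);
}

int main()
{
    auto out =
        attention_lens({{1, 12, 256, 64}, {1, 12, 64, 256}, {1, 12, 256, 256}, {1, 12, 256, 64}});
    std::cout << out[2] << "x" << out[3] << "\n"; // 256x64
}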
23 changes: 21 additions & 2 deletions src/targets/gpu/prefuse_ops.cpp
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
@@ -171,6 +172,17 @@ auto is_test_gemm(bool enable_attention)
});
}

auto is_bias_supported()
{
return match::make_basic_pred_matcher([=](instruction_ref) {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
return not enabled(MIGRAPHX_ENABLE_CK{});
#else
return true;
#endif
});
}

struct find_gemm_softmax_gemm
{
bool enable_attention = false;
@@ -182,8 +194,11 @@ struct find_gemm_softmax_gemm
.bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto add = match::name("add")(is_bias_supported(),
match::nargs(2),
match::any_arg(0, 1)(match::none_of(mul).bind("bias")));
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, gemm1))).bind("softmax");
match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1))).bind("softmax");

return match::name("dot")(
match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention))
@@ -210,7 +225,11 @@ struct find_gemm_softmax_gemm
});
}

auto inputs = gemm1_ins->inputs(); // A, B
auto inputs = gemm1_ins->inputs(); // A, B
if(contains(r.instructions, "bias"))
{
inputs.push_back(r.instructions["bias"]);
}
inputs.push_back(gemm2_ins->inputs().back()); // B1

mpm.get_module().replace_instruction(
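prefuse_ops.cpp now also matches an optional add feeding the softmax and binds it as "bias", but only when is_bias_supported() allows it: in builds with Composable Kernel the bias path is accepted only while the CK backend is not enabled, and in MLIR-only builds it is always accepted. A gate of roughly that shape is sketched below; the macro and environment-variable names are placeholders standing in for MIGRAPHX_USE_COMPOSABLEKERNEL and MIGRAPHX_ENABLE_CK.

#include <cstdlib>
#include <iostream>

// Stand-in for enabled(MIGRAPHX_ENABLE_CK{}): a runtime switch read from the environment.
bool ck_enabled_at_runtime() { return std::getenv("EXAMPLE_ENABLE_CK") != nullptr; }

bool bias_supported()
{
#ifdef EXAMPLE_USE_COMPOSABLEKERNEL
    // CK kernels do not take the fused bias, so only allow it when CK is switched off.
    return not ck_enabled_at_runtime();
#else
    // MLIR-only build: the bias input is always supported.
    return true;
#endif
}

int main() { std::cout << std::boolalpha << bias_supported() << "\n"; }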
83 changes: 83 additions & 0 deletions test/gpu/fuse_ops.cpp
@@ -160,6 +160,89 @@ TEST_CASE(pointwise_contiguous)
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_layout_convolution)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 320, 128, 128}};
migraphx::shape s2{migraphx::shape::float_type, {320, 320, 3, 3}, {2880, 1, 960, 320}};
migraphx::shape s3{migraphx::shape::float_type, {2, 320, 128, 128}, {5242880, 1, 40960, 320}};
// workspace for gpu::convolution; the required size can change based on gfx arch and rocm version,
// so for the unit-test just use an arbitrary number.
migraphx::shape s4{migraphx::shape::int8_type, {41943040}};

auto create_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x1 = mm->add_parameter("x1", s1);
auto x2 = mm->add_parameter("x2", s1);
auto weights = mm->add_parameter("weights", s2);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s1)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pwm = create_pointwise_module(
p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) {
auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins);
});
auto pw_ins =
mm->add_instruction(make_precompile_op("pointwise"), {x1, x2, alloc_ins}, {pwm});

auto alloc_ins2 =
mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
auto layout_op = migraphx::make_op("layout", {{"permutation", {0, 2, 3, 1}}});
auto layout_ins = mm->add_instruction(make_precompile_op(layout_op), {pw_ins, alloc_ins2});
auto conv_op = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}});
auto alloc_ins3 =
mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}}));
auto alloc_ins4 =
mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
auto conv =
mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}),
layout_ins,
weights,
alloc_ins3,
alloc_ins4);
mm->add_return({conv});
return p;
};
auto create_fused_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x1 = mm->add_parameter("x1", s1);
auto x2 = mm->add_parameter("x2", s1);
auto weights = mm->add_parameter("weights", s2);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s3)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pwm = create_pointwise_module(
p, "main:pointwise0", {x1, x2}, [=](auto* pm, const auto& inputs) {
auto mul_ins = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("sigmoid"), mul_ins);
});
auto pw_op = migraphx::make_op("pointwise");
auto pre_comp_op = migraphx::make_op(
"gpu::precompile_op",
{{"op", migraphx::to_value(pw_op)}, {"output_shape", migraphx::to_value(s3)}});

auto pw_ins = mm->add_instruction(pre_comp_op, {x1, x2, alloc_ins}, {pwm});

auto conv_op = migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}});
auto alloc_ins2 =
mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s4)}}));
auto alloc_ins3 =
mm->add_instruction(migraphx::make_op("allocate", {{"shape", to_value(s3)}}));
auto conv =
mm->add_instruction(migraphx::make_op("gpu::convolution", {{"op", conv_op.to_value()}}),
pw_ins,
weights,
alloc_ins2,
alloc_ins3);
mm->add_return({conv});
return p;
};
migraphx::program p1 = create_program();
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}

TEST_CASE(concat_pointwise_contiguous)
{
migraphx::shape s1 = migraphx::shape::from_permutation(
2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
@@ -1 +1 @@
01c376a0b9ebd251d5712fa14a448335a2bde780
0b2a75b274e45c7a510bfdae9071a97a69e75618
43 changes: 33 additions & 10 deletions test/verify/gemm_softmax_gemm_relu.cpp
@@ -27,31 +27,54 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
enum class bias
{
without,
with,
with_standard_shape
};

template <bias Config>
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<Config>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto m2_elements = m1_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
auto eight = mm->add_literal(migraphx::literal{m1_shape, eights});

b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);

std::optional<migraphx::instruction_ref> add_bias{std::nullopt};
if constexpr(Config == bias::with or Config == bias::with_standard_shape)
{
auto bias_shape = m1_shape;
if(Config != bias::with_standard_shape)
{
bias_shape = migraphx::shape::from_permutation(
bias_shape.type(), bias_shape.lens(), {0, 1, 3, 2});
}
auto bias_term = mm->add_parameter("4", bias_shape);
add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias_term);
}

auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}),
Config == bias::without ? scale : add_bias.value());
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
mm->add_instruction(migraphx::make_op("relu"), gemm2);
return p;
}
std::string section() const { return "gemm"; }
};

template struct gemm_softmax_gemm_relu<bias::without>;
template struct gemm_softmax_gemm_relu<bias::with>;
template struct gemm_softmax_gemm_relu<bias::with_standard_shape>;
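
The verify test above is now a template over a small bias enum, so a single create_program body covers the no-bias, biased, and standard-layout-bias variants via if constexpr, with the explicit instantiations registering each case. The toy example below shows the same mechanics; stage_count is an invented helper used only to illustrate the pattern.

#include <iostream>

enum class bias
{
    without,
    with,
    with_standard_shape
};

template <bias Config>
int stage_count()
{
    int stages = 4; // dot, mul, softmax, dot
    if constexpr(Config != bias::without)
        stages += 1; // the extra add feeding the softmax
    return stages;
}

template int stage_count<bias::without>();
template int stage_count<bias::with>();
template int stage_count<bias::with_standard_shape>();

int main()
{
    std::cout << stage_count<bias::without>() << " " << stage_count<bias::with>() << "\n"; // 4 5
}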
