Enable MLIR by default for more cases #2274
Changes from 4 commits
@@ -37,23 +37,13 @@ struct module;
 namespace gpu {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);

 bool mlir_enabled()
 {
 #ifdef MIGRAPHX_MLIR
-    const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
-    if(mlir_enabled)
-    {
-        return true;
-    }
-    else
-    {
-        std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
-                     "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
-                  << std::endl;
-        return false;
-    }
+    const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
+    return not mlir_disabled;
 #else
     return false;
 #endif
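For readers skimming the hunk above: a build compiled with MIGRAPHX_MLIR now treats MLIR as enabled unless the user explicitly opts out. Below is a minimal standalone sketch of that gate, not part of the diff; std::getenv stands in for MIGraphX's enabled() helper, so treating "1" as the set value is an assumption about its behaviour.

```cpp
// Sketch only (not the PR's code): the opt-out gate introduced above,
// approximated with std::getenv instead of enabled(MIGRAPHX_DISABLE_MLIR{}).
#include <cstdlib>
#include <string_view>

bool mlir_enabled_sketch()
{
#ifdef MIGRAPHX_MLIR
    const char* disable       = std::getenv("MIGRAPHX_DISABLE_MLIR");
    const bool mlir_disabled  = disable != nullptr and std::string_view{disable} == "1";
    return not mlir_disabled; // on by default; only an explicit opt-out turns it off
#else
    return false;             // builds without MLIR support never use it
#endif
}
```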
@@ -150,27 +140,70 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
     return {new_gemm_based_op, top_inputs};
 }

-MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
+enum class mlir_mode
 {
-    if(ins->name() != "convolution" and ins->name() != "quant_convolution")
-        return false;
-    value v = ins->get_operator().to_value();
-    auto group = v.at("group").to<int>();
-    if(group != 1)
-        return false;
-    // Avoid MLIR assertion: Index < Length && "Invalid index!"
-    if(ins->get_shape().lens().size() != 4)
-        return false;
-    return true;
-}
+    all,
+    int8,
+    fast,
+    none
+};

+auto is_mlir_dot(mlir_mode mode)
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(mode == mlir_mode::none)
+            return false;
+        if(ins->name() != "dot" and ins->name() != "quant_dot")
+            return false;
+        if(mode != mlir_mode::fast)
+            return true;
+        auto a = ins->inputs().front()->get_shape();
+        auto b = ins->inputs().back()->get_shape();
+        // auto m = a.lens()[a.lens().size() - 2];
+        // auto n = b.lens().back();
+        auto k = a.lens().back();
+        // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
+        // to avoid poor-performing GEMM kernels from MLIR
+        // To-do: Investigate a more precise strategy
+        return k <= 2048;
+    });
+}
+
+auto is_mlir_conv(mlir_mode mode)
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(mode == mlir_mode::none)
+            return false;
+        if(ins->name() != "convolution" and ins->name() != "quant_convolution")
+            return false;
+        value v = ins->get_operator().to_value();
+        auto group = v.at("group").to<int>();
+        if(group != 1)
+            return false;
+        // Avoid MLIR assertion: Index < Length && "Invalid index!"
+        if(ins->get_shape().lens().size() != 4)
+            return false;
+        if(ins->get_shape().type() == shape::int8_type)
+            return true;
+        if(mode != mlir_mode::fast)
+            return true;
+        auto w = ins->inputs().at(1)->get_shape();
+        if(w.lens().size() != 4)
+            return true;
+        if(w.lens()[2] != w.lens()[3])
+            return true;
+        return (w.lens()[3] % 3) != 0;
+    });
+}
+
 struct find_mlir_fused_ops
 {
+    mlir_mode conv_mode = mlir_mode::none;
+    mlir_mode dot_mode = mlir_mode::none;
     auto matcher() const
     {
         auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
-                .bind("gemm_based_op"));
+            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }

Review thread (on the new conv_mode/dot_mode members):
Comment: Nit: shouldn't we have an explicit constructor? Or …
Reply: Let me see if designated initializers work here (an explicit constructor won't allow that, though).
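A short aside on the thread above: designated initializers (C++20) only apply to aggregate types, and adding a user-declared constructor to find_mlir_fused_ops would make it a non-aggregate, which is the trade-off being discussed. A self-contained illustration follows; the struct name is shortened and this is not the PR's code.

```cpp
// Illustration only: why designated initializers and an explicit constructor conflict.
enum class mlir_mode
{
    all,
    int8,
    fast,
    none
};

struct fused_ops_example // aggregate: no user-declared constructors
{
    mlir_mode conv_mode = mlir_mode::none;
    mlir_mode dot_mode  = mlir_mode::none;
};

int main()
{
    // Works because fused_ops_example is an aggregate (C++20 designated initializers).
    // Declaring any constructor would make this line ill-formed.
    auto f = fused_ops_example{.conv_mode = mlir_mode::fast, .dot_mode = mlir_mode::all};
    return f.conv_mode == mlir_mode::fast ? 0 : 1;
}
```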
@@ -302,8 +335,11 @@ struct find_mlir_fused_ops
     }
 };

+template <auto Matcher>
 struct find_mlir_standalone_op
 {
+    mlir_mode mode = mlir_mode::none;
+    auto matcher() const { return Matcher(mode); }
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto conv_based_op = r.result;
@@ -325,15 +361,8 @@ struct find_mlir_standalone_op
     }
 };

-struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
-{
-    auto matcher() const { return is_mlir_conv; }
-};
-
-struct find_mlir_standalone_dot_op : find_mlir_standalone_op
-{
-    auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
-};
+using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
+using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;

 /**
  * @brief Declares a new MIGraphX environment variable which forces to generate
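The refactor above collapses two nearly identical structs into one template whose matcher function is chosen at compile time via a non-type template parameter. Here is a self-contained toy version of the same pattern, not the PR's code; the string return type is only for demonstration, whereas the real matchers return matcher objects.

```cpp
// Toy illustration of the template<auto Matcher> pattern used above (C++17).
#include <iostream>
#include <string>

enum class mlir_mode
{
    all,
    int8,
    fast,
    none
};

std::string is_mlir_conv_toy(mlir_mode) { return "conv-matcher"; }
std::string is_mlir_dot_toy(mlir_mode) { return "dot-matcher"; }

template <auto Matcher>
struct standalone_op_toy
{
    mlir_mode mode = mlir_mode::none;
    // The matcher function is baked in at compile time, so one struct covers both cases.
    auto matcher() const { return Matcher(mode); }
};

using convolution_op_toy = standalone_op_toy<&is_mlir_conv_toy>;
using dot_op_toy         = standalone_op_toy<&is_mlir_dot_toy>;

int main()
{
    std::cout << convolution_op_toy{mlir_mode::fast}.matcher() << ' '
              << dot_op_toy{mlir_mode::all}.matcher() << '\n';
}
```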
@@ -347,65 +376,42 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
  * intended to be primarily used by rocMLIR developers.
  */
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
-bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }

-bool is_requested(std::string_view option)
+bool is_requested(std::string_view option, bool fallback = false)
 {
-    assert(not is_self_decide());
     auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
+    if(string_value.empty())
+        return fallback;
     const auto options = split_string(string_value, ',');
     return contains(options, option);
 }

-bool is_enabled(std::string_view op_name, context* ctx)
-{
-    if(is_self_decide())
-    {
-        if(op_name == "fused")
-        {
-            return true;
-        }
-        else if(op_name == "convolution" or op_name == "quant_convolution")
-        {
-            if(ctx == nullptr)
-            {
-                return false;
-            }
-            else
-            {
-                const auto& device = ctx->get_current_device();
-                const std::string navi_family{"gfx110"};
-                return starts_with(device.get_gfx_name(), navi_family);
-            }
-        }
-        else
-        {
-            return false;
-        }
-    }
-    return is_requested(op_name);
-}
 } // namespace

 #endif // MIGRAPHX_MLIR

 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
-    if(is_enabled("fused", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_fused_ops{});
-    }
+    const auto& device = ctx->get_current_device();
+    const std::string navi_family{"gfx110"};
+    const bool is_navi = starts_with(device.get_gfx_name(), navi_family);

-    if(is_enabled("convolution", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_standalone_convolution_op{});
-    }
+    auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
+        if(is_requested(option))
+            return mlir_mode::all;
+        if(is_navi)
+            return mlir_mode::all;
+        return std::max(m1, m2);
+    };

-    if(is_enabled("dot", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_standalone_dot_op{});
-    }
+    mlir_mode mode = enabled(MIGRAPHX_ENABLE_MLIR{}) ? mlir_mode::fast : mlir_mode::none;
+
+    match::find_matches(
+        mpm, find_mlir_fused_ops{get_mode("fused", mlir_mode::fast), get_mode("fused", mode)});
+    match::find_matches(
+        mpm,
+        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
+        find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
 #else
     (void)mpm;
 #endif

Review thread (on the new MIGRAPHX_ENABLE_MLIR check in fuse_mlir::apply):
Comment: Hold on, since we now have …
Reply: No, DISABLE_MLIR will disable MLIR completely (and it's already handled in target.cpp), whereas ENABLE_MLIR will enable it for gemm fusions (it is not enabled by default there because it is not always faster).
Comment: Ah. So maybe we should rename it to something like …
Comment: Wanted to note the agreement from yesterday's meeting that this becomes MIGRAPHX_ENABLE_EXTRA_MLIR.
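One reading aid for the new get_mode logic above: because mlir_mode is declared as all, int8, fast, none, std::max(m1, m2) keeps whichever mode is declared later, i.e. the more restrictive one, while an explicit request via MIGRAPHX_MLIR_USE_SPECIFIC_OPS or a Navi (gfx110*) device forces all. The sketch below isolates just that resolution step; is_requested and is_navi are stubs, since the real values come from the environment and the GPU context.

```cpp
// Standalone sketch of the get_mode resolution above; is_requested/is_navi are stubbed.
#include <algorithm>
#include <string_view>

enum class mlir_mode
{
    all,  // most permissive: hand everything matching to MLIR
    int8,
    fast,
    none  // most restrictive: never use MLIR here
};

bool is_requested(std::string_view) { return false; } // stub for MIGRAPHX_MLIR_USE_SPECIFIC_OPS
constexpr bool is_navi = false;                        // stub for the gfx110* device check

mlir_mode get_mode(std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast)
{
    if(is_requested(option))
        return mlir_mode::all;
    if(is_navi)
        return mlir_mode::all;
    // Later enumerators are more restrictive, so std::max keeps the stricter of the two.
    return std::max(m1, m2);
}

int main()
{
    // Mirrors the calls in fuse_mlir::apply when MIGRAPHX_ENABLE_MLIR is unset:
    // fused convolutions resolve to fast, fused dots and standalone dots resolve to none.
    const mlir_mode mode = mlir_mode::none; // would be fast with MIGRAPHX_ENABLE_MLIR set
    const bool ok = get_mode("fused", mlir_mode::fast) == mlir_mode::fast and
                    get_mode("fused", mode) == mlir_mode::none and
                    get_mode("dot", mlir_mode::none) == mlir_mode::none;
    return ok ? 0 : 1;
}
```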