Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable MLIR by default for more cases #2274

Merged
merged 18 commits into from
Oct 16, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 85 additions & 79 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,13 @@ struct module;
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wanted to note the agreement from yesterday's meeting that this becomes MIGRAPHX_ENABLE_EXTRA_MLIR

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);

bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{

std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
return not mlir_disabled;
#else
return false;
#endif
Expand Down Expand Up @@ -150,27 +140,70 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return {new_gemm_based_op, top_inputs};
}

MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
enum class mlir_mode
{
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
return true;
all,
int8,
fast,
none
};

auto is_mlir_dot(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}

auto is_mlir_conv(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode != mlir_mode::fast)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
});
}

struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: shouldn't we have an explicit constructor? Or /*conv_mode=*/ comments when we create this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me see if designated initializers work here(an explicit constructor wont allow that though).

auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}

Expand Down Expand Up @@ -302,8 +335,11 @@ struct find_mlir_fused_ops
}
};

template <auto Matcher>
struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
Expand All @@ -325,15 +361,8 @@ struct find_mlir_standalone_op
}
};

struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return is_mlir_conv; }
};

struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;

/**
* @brief Declares a new MIGraphX environment variable which forces to generate
Expand All @@ -347,65 +376,42 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }

bool is_requested(std::string_view option)
bool is_requested(std::string_view option, bool fallback = false)
{
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}

bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace

#endif // MIGRAPHX_MLIR

void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_fused_ops{});
}
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
const bool is_navi = starts_with(device.get_gfx_name(), navi_family);

if(is_enabled("convolution", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(is_requested(option))
return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};

if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
mlir_mode mode = enabled(MIGRAPHX_ENABLE_MLIR{}) ? mlir_mode::fast : mlir_mode::none;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hold on, since we now have MIGRAPHX_DISABLE_MLIR, shouldn't this be keying off of _DISABLE_MLIR and not _ENABLE_MLIR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, DISABLE_MLIR will disable MLIR completely(and its already handled in target.cpp), whereas ENABLE_MLIR will enable it for gemm fusions(because it is not enabled by default because its not always faster).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. So maybe we should rename it to something like MIGRAPHX_ENABLE_MLIR_GEMM_FUSION because it's a very confusing variable name


match::find_matches(
mpm, find_mlir_fused_ops{get_mode("fused", mlir_mode::fast), get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
#endif
Expand Down