Skip to content

Commit

Permalink
[MLIR] Updates needed for general stride support (#2428)
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysz00 authored Dec 4, 2023
1 parent 609ad96 commit e3e0054
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 84 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@9e66e8050209f03349a41b6b497f0da2b285a53b -DBUILD_FAT_LIBROCKCOMPILER=On
ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
18 changes: 11 additions & 7 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ struct mlir_op
}
if(ins->name() == "@return")
{
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
Expand All @@ -139,9 +136,16 @@ get_fusable_input_op_stream(instruction_ref lower_input)
{
instruction_ref upper_input = lower_input;
std::vector<operation> op_stream;
while(
contains({"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
upper_input->name()))
while(contains({"slice",
"transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name()))
{
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
Expand Down
62 changes: 15 additions & 47 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
#warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
Expand Down Expand Up @@ -319,31 +319,30 @@ struct mlir_program
return result;
}

MlirType make_tensor(const shape& s) const
MlirType make_mlir_shaped(const shape& s) const
{
if(not s.standard())
MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
if(s.dynamic())
MIGRAPHX_THROW("MLIR does not support dynamic shapes");
std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet(
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
return rocmlirMIXRShapedTypeGet(
lens.size(), lens.data(), strides.data(), make_type(s.type()));
}

template <class Range>
std::vector<MlirType> make_tensors(const Range& r)
std::vector<MlirType> make_mlir_shapeds(const Range& r)
{
std::vector<MlirType> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
return make_tensor(s);
return make_mlir_shaped(s);
});
return result;
}

MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
{
auto in = make_tensors(inputs);
auto out = make_tensors(outputs);
auto in = make_mlir_shapeds(inputs);
auto out = make_mlir_shapeds(outputs);
return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
}

Expand Down Expand Up @@ -505,11 +504,7 @@ struct mlir_program

mlir_operation_state& add_results(const std::vector<shape>& outputs)
{
std::vector<shape> reshaped(outputs.size());
std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
return shape{r.type(), r.lens()};
});
auto x = prog->make_tensors(reshaped);
auto x = prog->make_mlir_shapeds(outputs);
if(not x.empty())
{
mlirOperationStateAddResults(&op_state, x.size(), x.data());
Expand Down Expand Up @@ -582,7 +577,7 @@ struct mlir_program
std::vector<shape> outputs = m.get_output_shapes();

std::vector<MlirLocation> arg_locs(inputs.size(), location);
auto body_inputs = make_tensors(inputs);
auto body_inputs = make_mlir_shapeds(inputs);
mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
MlirBlock result = fbody.get();
Expand All @@ -608,7 +603,7 @@ struct mlir_program
return "func.return";
if(ins->name() == "@literal")
{
return "tosa.const";
return "migraphx.literal";
}
return "migraphx." + ins->name();
}
Expand Down Expand Up @@ -667,7 +662,8 @@ struct mlir_program
if(ins->name() == "@literal")
{
literal r = ins->get_literal();
MlirType tensor_type = make_tensor(ins->get_shape());
MlirType shaped_type = make_mlir_shaped(ins->get_shape());
MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
MlirAttribute mlir_value_attr =
mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
ops.add_attributes({{"value", mlir_value_attr}});
Expand Down Expand Up @@ -945,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
auto param = m.get_parameter(name);
if(input.standard())
continue;
auto lens = input.lens();
auto strides = input.strides();
std::vector<operation> ops;
if(input.transposed())
{
auto perm = find_permutation(input);
auto iperm = invert_permutation(perm);
lens = reorder_dims(lens, iperm);
strides = reorder_dims(strides, iperm);
ops.push_back(make_op("transpose", {{"permutation", perm}}));
}
if(input.broadcasted())
{
std::transform(lens.begin(),
lens.end(),
strides.begin(),
lens.begin(),
[](auto len, auto stride) -> std::size_t {
if(stride == 0)
return 1;
return len;
});
ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
}
auto new_param =
std::accumulate(ops.begin(),
ops.end(),
m.add_parameter(name + ".0", shape{input.type(), lens}),
[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
auto new_param = m.add_parameter(name + ".0", input);
m.replace_instruction(param, new_param);
m.remove_instruction(param);
}
Expand Down
81 changes: 52 additions & 29 deletions test/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
}
}
)__migraphx__";
Expand All @@ -160,15 +160,38 @@ module {
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
return %2 : tensor<1x2x2x2xf32>
func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
%1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
}
}
)__migraphx__";
Expand All @@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
%1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
}
}
)__migraphx__";
Expand All @@ -219,10 +242,10 @@ TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
}
)__migraphx__";
Expand All @@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
%1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
return %2 : tensor<1x2x2x2xi32>
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
%1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
}
}
)__migraphx__";
Expand Down Expand Up @@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
}
}
)__migraphx__";
Expand All @@ -304,10 +327,10 @@ TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
}
)__migraphx__";
Expand Down

0 comments on commit e3e0054

Please sign in to comment.