Skip to content

Commit

Permalink
Merge pull request #2922 from ROCm/fix_parse_matmulinteger_611
Browse files Browse the repository at this point in the history
Fix parse MatMulinteger
  • Loading branch information
TedThemistokleous authored Mar 25, 2024
2 parents 05a7707 + 4cd4f52 commit c6f4a18
Show file tree
Hide file tree
Showing 18 changed files with 583 additions and 48 deletions.
3 changes: 2 additions & 1 deletion src/eliminate_data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add",
"scatternd_mul",
"scatternd_none",
"select_module"};
"select_module",
"quantizelinear"};
if(unsupported_types.empty())
return;

Expand Down
140 changes: 98 additions & 42 deletions src/onnx/parse_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,77 @@ struct parse_matmul : op_parser<parse_matmul>
return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}};
}

static void broadcast_dimensions(const onnx_parser::node_info& info,
const std::vector<size_t>& s0_lens,
const std::vector<size_t>& s1_lens,
const instruction_ref& a0,
const instruction_ref& a1,
instruction_ref& ba0,
instruction_ref& ba1)
{
// try broadcasting if dimensions other than last two do not match
if(not std::equal(
s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
{
auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
if(s0_lens != l0_broadcasted_lens)
{
ba0 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
}
if(s1_lens != l1_broadcasted_lens)
{
ba1 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
}
}
}

// Convert to int16 prior to a shift to ensure we preserve accuracy here then
// convert back to int8
static instruction_ref add_int8_shift(const onnx_parser::node_info& info,
instruction_ref& unshifted_input)
{
auto int8_shift = info.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int16_type}, {-128}});

auto unshifted_input_int16 = info.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int16_type}}),
unshifted_input);

auto input_shifted_int16 = info.add_common_op("add", unshifted_input_int16, int8_shift);

return info.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}),
input_shifted_int16);
}

static instruction_ref set_bias_arg(const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args,
const int index,
const instruction_ref& input)
{
if(args.size() > index)
{
instruction_ref bias_arg = args[index];
if(bias_arg->get_shape().type() != input->get_shape().type())
{
MIGRAPHX_THROW("PARSE_QUANT_DOT: zero point must be the same type as data");
}

return info.add_common_op("sub", input, bias_arg);
}
return input;
}

instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
Expand Down Expand Up @@ -85,55 +156,40 @@ struct parse_matmul : op_parser<parse_matmul>
{
auto s0_lens = a0->get_shape().lens();
auto s1_lens = a1->get_shape().lens();
instruction_ref ba0 = a0;
instruction_ref ba1 = a1;
// try broadcasting if dimensions other than last two do not match
if(not std::equal(
s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))

if(not is_quant_dot and args.size() > 2)
{
auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
auto output_lens =
compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
if(s0_lens != l0_broadcasted_lens)
{
ba0 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
}
if(s1_lens != l1_broadcasted_lens)
{
ba1 = info.add_instruction(
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
}
MIGRAPHX_THROW("PARSE_MATMUL: Bias Args not supported for MatMul");
}

// parse a_zero_point and b_zero_point values
if(args.size() > 2)
instruction_ref ba0 = set_bias_arg(info, args, 2, a0);
instruction_ref ba1 = set_bias_arg(info, args, 3, a1);

// Only INT8 or UINT8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::uint8_type,
migraphx::shape::int8_type};
const auto ba0_type = ba0->get_shape().type();
const auto ba1_type = ba1->get_shape().type();

if(is_quant_dot and
(not contains(supported_types, ba0_type) or not contains(supported_types, ba1_type)))
{
ba0 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba0);

ba0 = info.add_common_op("sub", ba0, args[2]);
if(args.size() > 3)
{
ba1 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba1);
ba1 = info.add_common_op("sub", ba1, args[3]);
}
dot_res = info.add_instruction(make_op("dot"), ba0, ba1);
dot_res = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::int32_type}}), dot_res);
MIGRAPHX_THROW("PARSE_MATMULINTEGER: Unsupported type");
}
else

auto is_same_type = (ba0_type == ba1_type);

if(is_quant_dot and not is_same_type)
{
dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
if(ba0_type == migraphx::shape::uint8_type)
ba0 = add_int8_shift(info, ba0);

if(ba1_type == migraphx::shape::uint8_type)
ba1 = add_int8_shift(info, ba1);
}

broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1);
dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
}

// squeeze the appended or prepended dimensions
Expand Down
2 changes: 0 additions & 2 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type);
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
// whiltelist supported Ops for the FP8
Expand Down Expand Up @@ -131,7 +130,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
// workaround for rocBLAS unsupported error when using uint8 in quant_dot
eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_dot"}},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
Expand Down
83 changes: 81 additions & 2 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5130,6 +5130,21 @@ def matmulinteger_dyn_error():
return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_invalid_type_error():
m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [None, 6, 16])
m2 = helper.make_tensor_value_info('2', TensorProto.INT16, [None, 16, 8])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [None, 6, 8])

node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2'],
outputs=['y'],
)

return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_uns_test():
m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3])
Expand All @@ -5145,12 +5160,76 @@ def matmulinteger_uns_test():
return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_int8_uint8_test():
m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3])
m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2'],
outputs=['y'],
)

return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_uns_zp_test():
m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3])
m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
zp1 = helper.make_tensor('3', TensorProto.UINT8, [], [12])
zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [0])
zp1 = helper.make_tensor('3', TensorProto.UINT8, [], [0])
zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [1])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2', '3', '4'],
outputs=['y'],
)

return ([node], [m1, m2], [y], [zp1, zp2])


@onnx_test()
def matmulinteger_int8_uint8_one_zp_test():
m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3])
m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
zp1 = helper.make_tensor('3', TensorProto.INT8, [], [5])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2', '3'],
outputs=['y'],
)

return ([node], [m1, m2], [y], [zp1])


@onnx_test()
def matmulinteger_int8_uint8_one_zp_error_test():
m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3])
m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
zp1 = helper.make_tensor('3', TensorProto.INT8, [], [5])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

node = onnx.helper.make_node(
'MatMulInteger',
inputs=['1', '2', '3'],
outputs=['y'],
)

return ([node], [m1, m2], [y], [zp1])


@onnx_test()
def matmulinteger_int8_uint8_dual_zp_test():
m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3])
m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
zp1 = helper.make_tensor('3', TensorProto.INT8, [], [1])
zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [1])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])

node = onnx.helper.make_node(
Expand Down
18 changes: 18 additions & 0 deletions test/onnx/matmulinteger_int8_uint8_dual_zp_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
 %matmulinteger_int8_uint8_dual_zp_test:�

1
2
3
4y"MatMulInteger%matmulinteger_int8_uint8_dual_zp_test**B3**B4Z
1


Z
2


b
y


B
Expand Down
17 changes: 17 additions & 0 deletions test/onnx/matmulinteger_int8_uint8_one_zp_error_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
 *matmulinteger_int8_uint8_one_zp_error_test:�

1
2
3y"MatMulInteger*matmulinteger_int8_uint8_one_zp_error_test**B3Z
1


Z
2


b
y


B
Expand Down
17 changes: 17 additions & 0 deletions test/onnx/matmulinteger_int8_uint8_one_zp_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
 $matmulinteger_int8_uint8_one_zp_test:�

1
2
3y"MatMulInteger$matmulinteger_int8_uint8_one_zp_test**B3Z
1


Z
2


b
y


B
Expand Down
16 changes: 16 additions & 0 deletions test/onnx/matmulinteger_int8_uint8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 matmulinteger_int8_uint8_test:x

1
2y"MatMulIntegermatmulinteger_int8_uint8_testZ
1


Z
2


b
y


B
Expand Down
Binary file added test/onnx/matmulinteger_invalid_type_error.onnx
Binary file not shown.
Binary file modified test/onnx/matmulinteger_uns_zp_test.onnx
Binary file not shown.
Loading

0 comments on commit c6f4a18

Please sign in to comment.