Skip to content

Commit

Permalink
Add QLinearMul operator (#2430)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec authored Nov 17, 2023
1 parent 7f93a81 commit 0102d44
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/onnx/parse_qlinearadd.cpp → src/onnx/parse_qlinearbinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace onnx {

/*
*********************************************************************************
* Reference: see QLinearAdd in *
* Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
Expand All @@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
com.microsoft.QLinearMul
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
General definition of binary QLinear* ops:
Inputs (7 - 8)
A : T
First operand.
Expand Down Expand Up @@ -88,15 +99,18 @@ namespace onnx {
*/

struct parse_qlinearadd : op_parser<parse_qlinearadd>
struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{
std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; }
std::vector<op_desc> operators() const
{
return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
}

// basic type checking for QLinearAdd Operator
void check_inputs(const std::vector<instruction_ref>& args) const
// basic type checking for binary QLinear Operator
void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{
if(args.size() < 7)
MIGRAPHX_THROW("QLINEARADD: missing inputs");
MIGRAPHX_THROW(op_name + ": missing inputs");

const auto& in_a = args[0];
const auto& in_b = args[3];
Expand All @@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type();
auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b)
MIGRAPHX_THROW("QLINEARADD: mismatched input types");
MIGRAPHX_THROW(op_name + ": mismatched input types");
}

instruction_ref parse(const op_desc& /* opd */,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);
check_inputs(args, opd.op_name);

// A
const auto& in_a = args[0];
Expand All @@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);

// C = A + B
auto out_c = info.add_common_op("add", dquant_a, dquant_b);
// C = op(A, B)
auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);

const auto& in_scale_c = args[6];

Expand Down
55 changes: 55 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6179,6 +6179,61 @@ def qlinearmatmul_3D_test():
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def qlinearmul_test():
    # QLinearMul on two UINT8 vectors of 64 elements, with every scale and
    # zero point supplied as a scalar (rank-0) initializer.
    def data_info(name):
        return helper.make_tensor_value_info(name, TensorProto.UINT8, [64])

    def scale(name, value):
        return helper.make_tensor(name, TensorProto.FLOAT, [], [value])

    def zero_point(name, value):
        return helper.make_tensor(name, TensorProto.UINT8, [], [value])

    # Graph inputs and output
    a = data_info('A')
    b = data_info('B')
    c = data_info('C')

    # Quantization parameters of A, B and C
    sc_a = scale('A_scale', 0.05)
    zero_pt_a = zero_point('A_zero_point', 0)

    sc_b = scale('B_scale', 0.05)
    zero_pt_b = zero_point('B_zero_point', 16)

    sc_c = scale('C_scale', 0.05)
    zero_pt_c = zero_point('C_zero_point', 100)

    input_names = [
        'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
        'C_scale', 'C_zero_point'
    ]
    node = onnx.helper.make_node('QLinearMul',
                                 inputs=input_names,
                                 outputs=['C'])

    initializers = [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]
    return ([node], [a, b], [c], initializers)


@onnx_test()
def qlinearmul_bcast_test():
    # QLinearMul on INT8 data where A has shape [64] and B has shape
    # [1, 1, 64], so the operands must be broadcast to a common shape.
    # All scales and zero points are scalar (rank-0) initializers.
    a = helper.make_tensor_value_info('A', TensorProto.INT8, [64])
    sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05])
    zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.INT8, [], [0])

    b = helper.make_tensor_value_info('B', TensorProto.INT8, [1, 1, 64])
    sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.05])
    # INT8 covers -128..127, so the zero point is written as -128; a value of
    # 128 is not representable and would depend on wrap-around when the
    # initializer is read back as an 8-bit integer.
    zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.INT8, [],
                                   [-128])

    sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.15])
    zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.INT8, [], [32])

    c = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 64])

    node = onnx.helper.make_node(
        'QLinearMul',
        inputs=[
            'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
            'C_scale', 'C_zero_point'
        ],
        outputs=['C'],
    )
    # Returned as: nodes, graph inputs, graph outputs, initializers.
    return ([node], [a, b], [c],
            [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def quantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
Expand Down
53 changes: 53 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5754,6 +5754,59 @@ TEST_CASE(qlinearmatmul_2D_test)
EXPECT(p.sort() == prog.sort());
}

TEST_CASE(qlinearmul_test)
{
    // Hand-built reference program for qlinearmul_test.onnx:
    // dequantize A and B, multiply them, then quantize the product.
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    auto in_a = main_mod->add_parameter("A", {migraphx::shape::uint8_type, {64}});
    auto in_b = main_mod->add_parameter("B", {migraphx::shape::uint8_type, {64}});

    // Scalar quantization parameters, added in the order A, B, C.
    auto scale_a = main_mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_a  = main_mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});

    auto scale_b = main_mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_b  = main_mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {16}});

    auto scale_c = main_mod->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto zero_c  = main_mod->add_literal(migraphx::literal{migraphx::shape::uint8_type, {100}});

    // Expands a scalar literal to the 64-element data shape.
    const auto broadcast_scalar = [&](migraphx::instruction_ref scalar) {
        return main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), scalar);
    };

    // Each broadcast is bound to a name before use so that instructions are
    // inserted in a fixed order: scale first, then zero point, then the consumer.
    auto scale_a_wide = broadcast_scalar(scale_a);
    auto zero_a_wide  = broadcast_scalar(zero_a);
    auto real_a       = main_mod->add_instruction(
        migraphx::make_op("dequantizelinear"), in_a, scale_a_wide, zero_a_wide);

    auto scale_b_wide = broadcast_scalar(scale_b);
    auto zero_b_wide  = broadcast_scalar(zero_b);
    auto real_b       = main_mod->add_instruction(
        migraphx::make_op("dequantizelinear"), in_b, scale_b_wide, zero_b_wide);

    auto product = main_mod->add_instruction(migraphx::make_op("mul"), real_a, real_b);

    auto scale_c_wide = broadcast_scalar(scale_c);
    auto zero_c_wide  = broadcast_scalar(zero_c);
    auto quantized    = main_mod->add_instruction(
        migraphx::make_op("quantizelinear"), product, scale_c_wide, zero_c_wide);

    main_mod->add_return({quantized});

    auto prog = migraphx::parse_onnx("qlinearmul_test.onnx");

    EXPECT(p.sort() == prog.sort());
}

migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
const migraphx::instruction_ref ins,
const migraphx::instruction_ref round,
Expand Down
Binary file added test/onnx/qlinearmul_bcast_test.onnx
Binary file not shown.
Binary file added test/onnx/qlinearmul_test.onnx
Binary file not shown.
75 changes: 75 additions & 0 deletions test/onnx/verify_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,81 @@ TEST_CASE(qlinearmatmul_3D_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearmul_test)
{
    // Evaluates qlinearmul_test.onnx on the "ref" target and compares the
    // output against precomputed values. Operator reference:
    // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
    migraphx::program p = migraphx::parse_onnx("qlinearmul_test.onnx");
    p.compile(migraphx::make_target("ref"));

    const int num_elements = 64;

    // A = 0, 2, 4, ..., 126 and B = 128, 126, 124, ..., 2
    std::vector<uint8_t> data_a;
    std::vector<uint8_t> data_b;
    for(int i = 0; i < num_elements; ++i)
    {
        data_a.push_back(static_cast<uint8_t>(2 * i));
        data_b.push_back(static_cast<uint8_t>(128 - 2 * i));
    }

    migraphx::shape shape_a{migraphx::shape::uint8_type, {64}};
    migraphx::shape shape_b{migraphx::shape::uint8_type, {64}};

    migraphx::parameter_map params;
    params["A"] = migraphx::argument(shape_a, data_a.data());
    params["B"] = migraphx::argument(shape_b, data_b.data());
    auto result = p.eval(params).back();

    std::vector<uint8_t> result_vector;
    result.visit([&](auto out) { result_vector.assign(out.begin(), out.end()); });

    // Expected quantized products; the run of 255 is where the result hits the uint8 maximum.
    std::vector<uint8_t> gold = {100, 111, 122, 132, 142, 151, 160, 169, 177, 185, 192, 199, 206,
                                 212, 218, 223, 228, 233, 237, 241, 244, 247, 250, 252, 254, 255,
                                 255, 255, 255, 255, 255, 255, 254, 252, 250, 247, 244, 241, 237,
                                 233, 228, 223, 218, 212, 206, 199, 192, 185, 177, 169, 160, 151,
                                 142, 132, 122, 111, 100, 89,  77,  65,  52,  39,  26,  12};

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearmul_bcast_test)
{
    // Evaluates qlinearmul_bcast_test.onnx on the "ref" target, feeding A with
    // shape {64} and B with shape {1, 1, 64}, and compares the output against
    // precomputed values. Operator reference:
    // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
    migraphx::program p = migraphx::parse_onnx("qlinearmul_bcast_test.onnx");
    p.compile(migraphx::make_target("ref"));

    const int num_elements = 64;

    // A = -64, -62, ..., 62 and B = 96, 94, ..., -30
    std::vector<int8_t> data_a;
    std::vector<int8_t> data_b;
    for(int i = 0; i < num_elements; ++i)
    {
        data_a.push_back(static_cast<int8_t>(-64 + 2 * i));
        data_b.push_back(static_cast<int8_t>(96 - 2 * i));
    }

    migraphx::shape shape_a{migraphx::shape::int8_type, {64}};
    migraphx::shape shape_b{migraphx::shape::int8_type, {1, 1, 64}};

    migraphx::parameter_map params;
    params["A"] = migraphx::argument(shape_a, data_a.data());
    params["B"] = migraphx::argument(shape_b, data_b.data());
    auto result = p.eval(params).back();

    std::vector<int8_t> result_vector;
    result.visit([&](auto out) { result_vector.assign(out.begin(), out.end()); });

    // Expected quantized products; both ends sit at the int8 limits (-128 and 127).
    std::vector<int8_t> gold = {-128, -128, -128, -128, -128, -128, -128, -128, -128, -126, -118,
                                -109, -101, -93,  -86,  -78,  -70,  -63,  -56,  -49,  -42,  -35,
                                -28,  -21,  -15,  -9,   -2,   4,    10,   15,   21,   27,   32,
                                37,   42,   47,   52,   57,   62,   66,   70,   75,   79,   83,
                                86,   90,   94,   97,   100,  103,  106,  109,  112,  115,  117,
                                119,  122,  124,  126,  127,  127,  127,  127,  127};

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
Expand Down

0 comments on commit 0102d44

Please sign in to comment.