From 846ccc3ec480714a740343eb3f0e4dc3c6661c2d Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Fri, 10 Nov 2023 12:29:11 +0000 Subject: [PATCH] Add QLinearMul operator --- ...qlinearadd.cpp => parse_qlinearbinary.cpp} | 40 ++++++---- test/onnx/gen_onnx.py | 54 +++++++++++++ test/onnx/onnx_test.cpp | 53 +++++++++++++ test/onnx/qlinearmul_bcast_test.onnx | Bin 0 -> 343 bytes test/onnx/qlinearmul_test.onnx | Bin 0 -> 306 bytes test/onnx/verify_onnx.cpp | 72 ++++++++++++++++++ 6 files changed, 206 insertions(+), 13 deletions(-) rename src/onnx/{parse_qlinearadd.cpp => parse_qlinearbinary.cpp} (80%) create mode 100644 test/onnx/qlinearmul_bcast_test.onnx create mode 100644 test/onnx/qlinearmul_test.onnx diff --git a/src/onnx/parse_qlinearadd.cpp b/src/onnx/parse_qlinearbinary.cpp similarity index 80% rename from src/onnx/parse_qlinearadd.cpp rename to src/onnx/parse_qlinearbinary.cpp index 81f00e71d6a..6a4f0b5d172 100644 --- a/src/onnx/parse_qlinearadd.cpp +++ b/src/onnx/parse_qlinearbinary.cpp @@ -36,7 +36,7 @@ namespace onnx { /* ********************************************************************************* - * Reference: see QLinearAdd in * + * Reference: see QLinearAdd, QLinearMul in * * https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md * ********************************************************************************* @@ -49,6 +49,17 @@ namespace onnx { This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + com.microsoft.QLinearMul + Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting + support). + + C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point + + Version + This version of the operator has been available since version 1 of the 'com.microsoft' operator + set. + + General definition of binary QLinear* ops: Inputs (7 - 8) A : T First operand. @@ -88,15 +99,18 @@ namespace onnx { */ -struct parse_qlinearadd : op_parser +struct parse_qlinearbinary : op_parser { - std::vector operators() const { return {{"QLinearAdd"}}; } + std::vector operators() const + { + return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}}; + } - // basic type checking for QLinearAdd Operator - void check_inputs(const std::vector& args) const + // basic type checking for binary QLinear Operator + void check_inputs(const std::vector& args, const std::string& op_name) const { if(args.size() < 7) - MIGRAPHX_THROW("QLINEARADD: missing inputs"); + MIGRAPHX_THROW(op_name + ": missing inputs"); const auto& in_a = args[0]; const auto& in_b = args[3]; @@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser auto type_a = sh_a.type(); auto type_b = sh_b.type(); if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type) - MIGRAPHX_THROW("QLINEARADD: unsupported input type"); + MIGRAPHX_THROW(op_name + ": unsupported input type"); if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type) - MIGRAPHX_THROW("QLINEARADD: unsupported input type"); + MIGRAPHX_THROW(op_name + ": unsupported input type"); if(type_a != type_b) - MIGRAPHX_THROW("QLINEARADD: mismatched input types"); + MIGRAPHX_THROW(op_name + ": mismatched input types"); } - instruction_ref parse(const op_desc& /* opd */, + instruction_ref parse(const op_desc& opd, const onnx_parser& /*parser*/, const onnx_parser::node_info& info, const std::vector& args) const { - check_inputs(args); + check_inputs(args, opd.op_name); // A const auto& in_a = args[0]; @@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser const auto& in_zero_pt_b = args[5]; auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info); - // C = A + B - auto out_c = info.add_common_op("add", dquant_a, dquant_b); + // C = op(A, B) + auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b); const auto& in_scale_c = args[6]; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index dd1f90d755f..12e18e8c159 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -6008,6 +6008,60 @@ def qlinearmatmul_3D_test(): [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) +@onnx_test() +def qlinearmul_test(): + a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64]) + sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.005]) + zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.UINT8, [], [0]) + + b = helper.make_tensor_value_info('B', TensorProto.UINT8, [64]) + sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.005]) + zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.UINT8, [], [64]) + + sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.5]) + zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.UINT8, [], [64]) + + c = helper.make_tensor_value_info('C', TensorProto.UINT8, [64]) + + node = onnx.helper.make_node( + 'QLinearMul', + inputs=[ + 'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point', + 'C_scale', 'C_zero_point' + ], + outputs=['C'], + ) + return ([node], [a, b], [c], + [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) + + +@onnx_test() +def qlinearmul_bcast_test(): + a = helper.make_tensor_value_info('A', TensorProto.INT8, [64]) + sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.005]) + zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.INT8, [], [0]) + + b = helper.make_tensor_value_info('B', TensorProto.INT8, [1, 1, 64]) + sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.005]) + zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.INT8, [], [64]) + + sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.5]) + zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.INT8, [], [-64]) + + c = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 64]) + + node = onnx.helper.make_node( + 'QLinearAdd', + inputs=[ + 'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point', + 'C_scale', 'C_zero_point' + ], + outputs=['C'], + ) + return ([node], [a, b], [c], + [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) + + @onnx_test() def quantizelinear_test(): arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 011fff9e8e9..f329a5d5228 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -5754,6 +5754,59 @@ TEST_CASE(qlinearmatmul_2D_test) EXPECT(p.sort() == prog.sort()); } +TEST_CASE(qlinearmul_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}}); + auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}}); + + auto sc_a = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.005}}); + auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}}); + + auto sc_b = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.005}}); + auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}}); + + auto sc_c = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); + auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}}); + + auto scale_a_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_a); + + auto z_pt_a_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_a); + + auto fp_a = + mm->add_instruction(migraphx::make_op("dequantizelinear"), a, scale_a_bcast, z_pt_a_bcast); + + auto scale_b_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_b); + + auto z_pt_b_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_b); + + auto fp_b = + mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scale_b_bcast, z_pt_b_bcast); + + auto fp_c = mm->add_instruction(migraphx::make_op("mul"), fp_a, fp_b); + + auto scale_c_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_c); + + auto z_pt_c_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_c); + + auto c = + mm->add_instruction(migraphx::make_op("quantizelinear"), fp_c, scale_c_bcast, z_pt_c_bcast); + + mm->add_return({c}); + + auto prog = migraphx::parse_onnx("qlinearmul_test.onnx"); + + EXPECT(p.sort() == prog.sort()); +} + migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m, const migraphx::instruction_ref ins, const migraphx::instruction_ref round, diff --git a/test/onnx/qlinearmul_bcast_test.onnx b/test/onnx/qlinearmul_bcast_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..802086f4d6b77953b5eeb5cbc5de2c095caa308d GIT binary patch literal 343 zcmdmSJ~>2|FVR3o$w?aRvH-ZE{RW!C|YGpa7#13)l6<)=n^c zwS)zjwHO(kP^?E*2XzHRor4n(vST3X7#JAronX$=k^!2>b>Ke~Fgo!dJ2Q$O=w~4= pE)EW6Ar>ws4u>dlpjU;2xCDR#JU{_PD0WB!N;qSaaAFb=1^{NJTv`AC literal 0 HcmV?d00001 diff --git a/test/onnx/qlinearmul_test.onnx b/test/onnx/qlinearmul_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2430c8a0913e4940d547fc11083738b06ec3ed5c GIT binary patch literal 306 zcmd#3;^TEPFDZ` literal 0 HcmV?d00001 diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp index 7a711a8ed20..9c4365dfd47 100644 --- a/test/onnx/verify_onnx.cpp +++ b/test/onnx/verify_onnx.cpp @@ -1895,6 +1895,78 @@ TEST_CASE(qlinearmatmul_3D_test) EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); } +TEST_CASE(qlinearmul_test) +{ + // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul + migraphx::program p = migraphx::parse_onnx("qlinearmul_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape a{migraphx::shape::uint8_type, {64}}; + std::vector data_a = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + + migraphx::shape b{migraphx::shape::uint8_type, {64}}; + std::vector data_b = {128, 126, 124, 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, + 102, 100, 98, 96, 94, 92, 90, 88, 86, 84, 82, 80, 78, + 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, + 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, + 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2}; + + migraphx::parameter_map pp; + pp["A"] = migraphx::argument(a, data_a.data()); + pp["B"] = migraphx::argument(b, data_b.data()); + auto result = p.eval(pp).back(); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(qlinearmul_bcast_test) +{ + // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul + migraphx::program p = migraphx::parse_onnx("qlinearmul_bcast_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape a{migraphx::shape::int8_type, {64}}; + std::vector data_a = {-64, -62, -60, -58, -56, -54, -52, -50, -48, -46, -44, -42, -40, + -38, -36, -34, -32, -30, -28, -26, -24, -22, -20, -18, -16, -14, + -12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, + 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, + 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + + migraphx::shape b{migraphx::shape::int8_type, {1, 1, 64}}; + std::vector data_b = {96, 94, 92, 90, 88, 86, 84, 82, 80, 78, 76, 74, 72, + 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 46, + 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, + 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, -2, -4, -6, + -8, -10, -12, -14, -16, -18, -20, -22, -24, -26, -28, -30}; + + migraphx::parameter_map pp; + pp["A"] = migraphx::argument(a, data_a.data()); + pp["B"] = migraphx::argument(b, data_b.data()); + auto result = p.eval(pp).back(); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, + -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, + -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, + -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, + -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + TEST_CASE(resize_downsample_f_test) { migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");