Skip to content

Commit

Permalink
Add QLinearMul operator
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec committed Nov 10, 2023
1 parent 35e5298 commit 26e8dc7
Show file tree
Hide file tree
Showing 6 changed files with 331 additions and 0 deletions.
152 changes: 152 additions & 0 deletions src/onnx/parse_qlinearmul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/*
*********************************************************************************
* Reference: see QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
Inputs (7 - 8)
A :
T First operand.
A_scale : tensor(float)
Input A's scale. It's a scalar, which means a per-tensor/layer quantization.
A_zero_point (optional) : T
Input A zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
B : T
Second operand.
B_scale : tensor(float)
Input B's scale. It's a scalar, which means a per-tensor/layer quantization.
B_zero_point (optional) : T
Input B zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
C_scale : tensor(float)
Output scale. It's a scalar, which means a per-tensor/layer quantization.
C_zero_point (optional) : T
Output zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Outputs
C : T
Result, has same element type as two inputs
Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit signed and unsigned tensors.
*/

struct parse_qlinearmul : op_parser<parse_qlinearmul>
{
std::vector<op_desc> operators() const { return {{"QLinearMul"}}; }

// basic type checking for QLinearMul Operator
void check_inputs(const std::vector<instruction_ref>& args) const
{
if(args.size() < 7)
MIGRAPHX_THROW("QLINEARMUL: missing inputs");

const auto& in_a = args[0];
const auto& in_b = args[3];

auto sh_a = in_a->get_shape();
auto sh_b = in_b->get_shape();

auto type_a = sh_a.type();
auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARMUL: unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARMUL: unsupported input type");
if(type_a != type_b)
MIGRAPHX_THROW("QLINEARMUL: mismatched input types");
}

instruction_ref parse(const op_desc& /* opd */,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);

// A
const auto& in_a = args[0];
const auto& in_scale_a = args[1];
const auto& in_zero_pt_a = args[2];

auto dquant_a = bcast_qdq_instr("dequantizelinear", in_a, in_scale_a, in_zero_pt_a, info);

// B
const auto& in_b = args[3];
const auto& in_scale_b = args[4];
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);

// C = A * B
auto out_c = info.add_common_op("mul", dquant_a, dquant_b);

const auto& in_scale_c = args[6];

// zero_pt for C is supplied as the last optional argument..
if(args.size() == 8)
return (bcast_qdq_instr("quantizelinear", out_c, in_scale_c, args[7], info));

// if no zero_pt: just broadcast the scale..
auto bcast_scale_c = bcast_scalar_instr(out_c->get_shape(), in_scale_c, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), out_c, bcast_scale_c));
}
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
54 changes: 54 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6008,6 +6008,60 @@ def qlinearmatmul_3D_test():
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def qlinearmul_test():
a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64])
sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.005])
zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.UINT8, [], [0])

b = helper.make_tensor_value_info('B', TensorProto.UINT8, [64])
sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.005])
zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.UINT8, [], [64])

sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.5])
zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.UINT8, [], [64])

c = helper.make_tensor_value_info('C', TensorProto.UINT8, [64])

node = onnx.helper.make_node(
'QLinearMul',
inputs=[
'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
'C_scale', 'C_zero_point'
],
outputs=['C'],
)
return ([node], [a, b], [c],
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def qlinearmul_bcast_test():
a = helper.make_tensor_value_info('A', TensorProto.INT8, [64])
sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.005])
zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.INT8, [], [0])

b = helper.make_tensor_value_info('B', TensorProto.INT8, [1, 1, 64])
sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.005])
zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.INT8, [], [64])

sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.5])
zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.INT8, [], [-64])

c = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 64])

node = onnx.helper.make_node(
'QLinearAdd',
inputs=[
'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
'C_scale', 'C_zero_point'
],
outputs=['C'],
)
return ([node], [a, b], [c],
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def quantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
Expand Down
53 changes: 53 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5754,6 +5754,59 @@ TEST_CASE(qlinearmatmul_2D_test)
EXPECT(p.sort() == prog.sort());
}

TEST_CASE(qlinearmul_test)
{
migraphx::program p;
auto* mm = p.get_main_module();

auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}});
auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}});

auto sc_a = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.005}});
auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});

auto sc_b = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.005}});
auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}});

auto sc_c = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}});

auto scale_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_a);

auto z_pt_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_a);

auto fp_a =
mm->add_instruction(migraphx::make_op("dequantizelinear"), a, scale_a_bcast, z_pt_a_bcast);

auto scale_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_b);

auto z_pt_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_b);

auto fp_b =
mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scale_b_bcast, z_pt_b_bcast);

auto fp_c = mm->add_instruction(migraphx::make_op("mul"), fp_a, fp_b);

auto scale_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_c);

auto z_pt_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_c);

auto c =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_c, scale_c_bcast, z_pt_c_bcast);

mm->add_return({c});

auto prog = migraphx::parse_onnx("qlinearmul_test.onnx");

EXPECT(p.sort() == prog.sort());
}

migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
const migraphx::instruction_ref ins,
const migraphx::instruction_ref round,
Expand Down
Binary file added test/onnx/qlinearmul_bcast_test.onnx
Binary file not shown.
Binary file added test/onnx/qlinearmul_test.onnx
Binary file not shown.
72 changes: 72 additions & 0 deletions test/onnx/verify_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,78 @@ TEST_CASE(qlinearmatmul_3D_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearmul_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx::program p = migraphx::parse_onnx("qlinearmul_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::shape a{migraphx::shape::uint8_type, {64}};
std::vector<uint8_t> data_a = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,
26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50,
52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76,
78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102,
104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126};

migraphx::shape b{migraphx::shape::uint8_type, {64}};
std::vector<uint8_t> data_b = {128, 126, 124, 122, 120, 118, 116, 114, 112, 110, 108, 106, 104,
102, 100, 98, 96, 94, 92, 90, 88, 86, 84, 82, 80, 78,
76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52,
50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26,
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2};

migraphx::parameter_map pp;
pp["A"] = migraphx::argument(a, data_a.data());
pp["B"] = migraphx::argument(b, data_b.data());
auto result = p.eval(pp).back();

std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<uint8_t> gold = {64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearmul_bcast_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx::program p = migraphx::parse_onnx("qlinearmul_bcast_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::shape a{migraphx::shape::int8_type, {64}};
std::vector<int8_t> data_a = {-64, -62, -60, -58, -56, -54, -52, -50, -48, -46, -44, -42, -40,
-38, -36, -34, -32, -30, -28, -26, -24, -22, -20, -18, -16, -14,
-12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12,
14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38,
40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62};

migraphx::shape b{migraphx::shape::int8_type, {1, 1, 64}};
std::vector<int8_t> data_b = {96, 94, 92, 90, 88, 86, 84, 82, 80, 78, 76, 74, 72,
70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 46,
44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20,
18, 16, 14, 12, 10, 8, 6, 4, 2, 0, -2, -4, -6,
-8, -10, -12, -14, -16, -18, -20, -22, -24, -26, -28, -30};

migraphx::parameter_map pp;
pp["A"] = migraphx::argument(a, data_a.data());
pp["B"] = migraphx::argument(b, data_b.data());
auto result = p.eval(pp).back();

std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<int8_t> gold = {-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64,
-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64,
-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64,
-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64,
-64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64, -64};

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
Expand Down

0 comments on commit 26e8dc7

Please sign in to comment.