Skip to content

Commit

Permalink
QLinearConcat tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec committed Nov 30, 2023
1 parent 9752ed7 commit bc838bd
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 0 deletions.
50 changes: 50 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6251,6 +6251,56 @@ def qlinearaveragepool_nt_cip_test():
return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])


@onnx_test()
def qlinearconcat_test():
y_scale = helper.make_tensor('1', TensorProto.FLOAT, [], [0.5])
y_zero_point = helper.make_tensor('2', TensorProto.INT8, [], [2])

t0 = helper.make_tensor_value_info('t0', TensorProto.INT8, [2])
s0 = helper.make_tensor('3', TensorProto.FLOAT, [], [0.5])
zp0 = helper.make_tensor('4', TensorProto.INT8, [], [1])

t1 = helper.make_tensor_value_info('t1', TensorProto.INT8, [3])
s1 = helper.make_tensor('5', TensorProto.FLOAT, [], [0.25])
zp1 = helper.make_tensor('6', TensorProto.INT8, [], [0])

y = helper.make_tensor_value_info('out', TensorProto.INT8, [5])

node = onnx.helper.make_node(
'QLinearConcat',
inputs=['1', '2', 't0', '3', '4', 't1', '5', '6'],
axis=0,
outputs=['out'],
)

return ([node], [t0, t1], [y], [y_scale, y_zero_point, s0, zp0, s1, zp1])


@onnx_test()
def qlinearconcat_3d_test():
y_scale = helper.make_tensor('1', TensorProto.FLOAT, [], [0.5])
y_zero_point = helper.make_tensor('2', TensorProto.INT8, [], [2])

t0 = helper.make_tensor_value_info('t0', TensorProto.INT8, [3, 4, 2])
s0 = helper.make_tensor('3', TensorProto.FLOAT, [], [0.5])
zp0 = helper.make_tensor('4', TensorProto.INT8, [], [10])

t1 = helper.make_tensor_value_info('t1', TensorProto.INT8, [3, 2, 2])
s1 = helper.make_tensor('5', TensorProto.FLOAT, [], [0.4])
zp1 = helper.make_tensor('6', TensorProto.INT8, [], [20])

y = helper.make_tensor_value_info('out', TensorProto.UINT8, [3, 6, 2])

node = onnx.helper.make_node(
'QLinearConcat',
inputs=['1', '2', 't0', '3', '4', 't1', '5', '6'],
axis=1,
outputs=['out'],
)

return ([node], [t0, t1], [y], [y_scale, y_zero_point, s0, zp0, s1, zp1])


@onnx_test()
def qlinearconv_test():
# https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
Expand Down
53 changes: 53 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5645,6 +5645,59 @@ TEST_CASE(qlinearaveragepool_notset_test)
EXPECT(p == prog);
}

TEST_CASE(qlinearconcat_test)
{
migraphx::program p;
auto* mm = p.get_main_module();

auto sc_y = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_y = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {2}});

auto sc_0 = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_0 = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {1}});

auto sc_1 = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.25}});
auto z_pt_1 = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {0}});

auto t0 = mm->add_parameter("t0", {migraphx::shape::int8_type, {2}});
auto t1 = mm->add_parameter("t1", {migraphx::shape::int8_type, {3}});

auto scale_0_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), sc_0);

auto z_pt_0_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), z_pt_0);

auto fp_0 =
mm->add_instruction(migraphx::make_op("dequantizelinear"), t0, scale_0_bcast, z_pt_0_bcast);

auto scale_1_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), sc_1);

auto z_pt_1_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), z_pt_1);

auto fp_1 =
mm->add_instruction(migraphx::make_op("dequantizelinear"), t1, scale_1_bcast, z_pt_1_bcast);

auto fp_y = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), fp_0, fp_1);

auto scale_y_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), sc_y);

auto z_pt_y_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), z_pt_y);

auto y =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast);

mm->add_return({y});

auto prog = migraphx::parse_onnx("qlinearconcat_test.onnx");

EXPECT(p == prog);
}

TEST_CASE(qlinearconv_test)
{
migraphx::program p;
Expand Down
Binary file added test/onnx/qlinearconcat_3d_test.onnx
Binary file not shown.
Binary file added test/onnx/qlinearconcat_test.onnx
Binary file not shown.
46 changes: 46 additions & 0 deletions test/onnx/verify_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1932,6 +1932,52 @@ TEST_CASE(qlinearaveragepool_nt_cip_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearconcat_test)
{
auto p = migraphx::parse_onnx("qlinearconcat_test.onnx");
p.compile(migraphx::make_target("ref"));

std::vector<int8_t> data_t0 = {2, 3};
migraphx::shape s_t0{migraphx::shape::int8_type, {2}};
migraphx::parameter_map pp;
pp["t0"] = migraphx::argument(s_t0, data_t0.data());

std::vector<int8_t> data_t1 = {6, 8, 10};
migraphx::shape s_t1{migraphx::shape::int8_type, {3}};
pp["t1"] = migraphx::argument(s_t1, data_t1.data());

auto result = p.eval(pp).back();
std::vector<int8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<int8_t> gold = {3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearconcat_3d_test)
{
auto p = migraphx::parse_onnx("qlinearconcat_3d_test.onnx");
p.compile(migraphx::make_target("ref"));

std::vector<int8_t> data_t0 = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
migraphx::shape s_t0{migraphx::shape::int8_type, {3, 4, 2}};
migraphx::parameter_map pp;
pp["t0"] = migraphx::argument(s_t0, data_t0.data());

std::vector<int8_t> data_t1 = {25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25};
migraphx::shape s_t1{migraphx::shape::int8_type, {3, 2, 2}};
pp["t1"] = migraphx::argument(s_t1, data_t1.data());

auto result = p.eval(pp).back();
std::vector<uint8_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<int8_t> gold = {2, 2, 2, 2, 2, 2, 2, 2, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2,
2, 2, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 2, 2, 6, 6, 6, 6};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(qlinearconv_test)
{
// https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
Expand Down

0 comments on commit bc838bd

Please sign in to comment.