From f25606f97ab5eb14d8a3aa838b1b8d42513e5387 Mon Sep 17 00:00:00 2001
From: Charlie Lin
Date: Tue, 17 Oct 2023 12:06:22 -0400
Subject: [PATCH 1/6] 2 Input Reshape `ref` implementation (#2304)

---
 src/include/migraphx/op/reshape.hpp           | 58 +++++++++--
 src/onnx/parse_reshape.cpp                    | 22 ++--
 test/onnx/gen_onnx.py                         | 18 ++++
 test/onnx/onnx_test.cpp                       | 94 ++++++++++++------
 .../onnx/reshape_variable_input_dyn_test.onnx | Bin 0 -> 153 bytes
 test/onnx/reshape_variable_input_test.onnx    | 17 ++++
 test/op_shape_test.cpp                        | 23 ++++-
 test/ref/reshape.cpp                          | 78 ++++++++++++++-
 8 files changed, 262 insertions(+), 48 deletions(-)
 create mode 100644 test/onnx/reshape_variable_input_dyn_test.onnx
 create mode 100644 test/onnx/reshape_variable_input_test.onnx

diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp
index 90843c3ecb1..a521461c829 100644
--- a/src/include/migraphx/op/reshape.hpp
+++ b/src/include/migraphx/op/reshape.hpp
@@ -36,6 +36,22 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
+/**
+ * 1 input version:
+ * reshape(input_data)
+ * this.dims = output_dims
+ * Makes a copy of input_data to the output shape.
+ *
+ * 2 input version:
+ * reshape(input_data, output_buffer)
+ * this.dims = unset
+ * Copies input_data to output_buffer; output_buffer already has the output shape.
+ * This version does not fail gracefully if the input shape and the output_buffer shape are
+ * incompatible; a runtime check throws when the number of elements does not match. This
+ * version should only be used for dynamic reshapes (output dimensions only known at
+ * runtime). If output_buffer has a static shape during compile/parse, use the 1 input
+ * version instead.
+ */
 struct reshape
 {
     std::vector<int64_t> dims;
@@ -215,32 +231,56 @@ struct reshape
 
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this, true}.has(1);
+        check_shapes{inputs, *this, true}.has(1, 2);
         auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
         if(n_neg_dims > 1)
             MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
 
         auto s0 = inputs.front();
-        if(s0.dynamic())
+        if(inputs.size() == 1)
         {
-            return dyn_compute_shape(s0);
+            if(s0.dynamic())
+            {
+                return dyn_compute_shape(s0);
+            }
+            else
+            {
+                return static_compute_shape(inputs, n_neg_dims);
+            }
         }
         else
         {
-            return static_compute_shape(inputs, n_neg_dims);
+            return inputs.back();
         }
     }
 
     argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
         assert(dyn_out.computed_shape.standard());
-        argument result{dyn_out.computed_shape};
+        if(args.size() == 1)
+        {
+            argument result{dyn_out.computed_shape};
 
-        visit_all(result, args[0])([&](auto output, auto input) {
-            std::copy(input.begin(), input.end(), output.begin());
-        });
-        return result;
+            visit_all(result, args[0])([&](auto output, auto input) {
+                std::copy(input.begin(), input.end(), output.begin());
+            });
+            return result;
+        }
+        else
+        {
+            // 2 arg
+            if(args[0].get_shape().elements() != args[1].get_shape().elements())
+            {
+                MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. 
Input: " + + std::to_string(args[0].get_shape().elements()) + + " Output buffer: " + std::to_string(args[1].get_shape().elements())); + } + visit_all(args[1], args[0])([&](auto output, auto input) { + std::copy(input.begin(), input.end(), output.begin()); + }); + return args[1]; + } } }; diff --git a/src/onnx/parse_reshape.cpp b/src/onnx/parse_reshape.cpp index a22f1389c33..dba225d7779 100644 --- a/src/onnx/parse_reshape.cpp +++ b/src/onnx/parse_reshape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,15 +45,25 @@ struct parse_reshape : op_parser { literal s = parser.parse_value(info.attributes.at("shape")); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]); } - if(args.size() == 2) + else { + // 2 inputs auto s = args[1]->eval(); - check_arg_empty(s, "Reshape: non-constant shape input is not supported"); - s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + if(s.empty()) + { + // arg[1] not eval-able + auto alloc_ins = info.add_instruction( + make_op("allocate", {{"buf_type", args[0]->get_shape().type()}}), args[1]); + return info.add_instruction(make_op("reshape"), args[0], alloc_ins); + } + else + { + s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]); + } } - - return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]); } }; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 73d42789de3..581a264f35f 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -6065,6 +6065,24 @@ def reshape_non_standard_test(): return ([trans, res], [x], [y]) +@onnx_test() +def reshape_variable_input_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3]) + x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8]) + node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2']) + return ([node], [x, x_shape], [y]) + + +@onnx_test() +def reshape_variable_input_dyn_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 3]) + x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None, 6]) + node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2']) + return ([node], [x, x_shape], [y]) + + @onnx_test() def resize_downsample_f_test(): scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index a1fb16caa7f..368b8918e2f 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -362,10 +362,10 @@ TEST_CASE(averagepool_notset_test) auto* mm = p.get_main_module(); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto ins = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {2, 2, 2, 2}}, - {"stride", {2, 2}}, - {"lengths", {6, 6}}}), + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {2, 2, 2, 2}}, + {"stride", {2, 2}}, + {"lengths", {6, 6}}}), input); auto 
ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins); @@ -382,11 +382,11 @@ TEST_CASE(averagepool_nt_cip_test) auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); std::vector pads = {0, 0, 0, 0, 0, 0, 1, 1}; auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); - auto ret = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {0, 0, 0, 0}}, - {"stride", {2, 2}}, - {"lengths", {6, 6}}}), + auto ret = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {6, 6}}}), ins_pad); mm->add_return({ret}); @@ -426,11 +426,11 @@ TEST_CASE(averagepool_sl_cip_test) auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); std::vector pads = {0, 0, 1, 1, 0, 0, 0, 0}; auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); - auto ret = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {0, 0, 0, 0}}, - {"stride", {1, 1}}, - {"lengths", {2, 2}}}), + auto ret = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), ins_pad); mm->add_return({ret}); auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx"); @@ -444,10 +444,10 @@ TEST_CASE(averagepool_same_upper_test) auto* mm = p.get_main_module(); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto ins = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {1, 1, 1, 1}}, - {"stride", {1, 1}}, - {"lengths", {2, 2}}}), + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1, 1, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), input); auto ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins); @@ -1634,7 +1634,7 @@ TEST_CASE(conv_transpose_input_pads_asymm_1d_test) auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}}); auto l2 = mm->add_instruction( migraphx::make_op("convolution_backwards", - {{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}), + {{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}), l0, l1); mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}), @@ -1668,7 +1668,7 @@ TEST_CASE(conv_transpose_output_padding_3d_test) auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}}); auto l2 = mm->add_instruction( migraphx::make_op("convolution_backwards", - {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), + {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), l0, l1); mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2); @@ -1701,7 +1701,7 @@ TEST_CASE(conv_transpose_output_shape_3d_test) auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}}); auto l2 = mm->add_instruction( migraphx::make_op("convolution_backwards", - {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), + {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), l0, l1); 
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2); @@ -1996,7 +1996,7 @@ TEST_CASE(equal_test) auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2); auto ret = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), eq); mm->add_return({ret}); @@ -2016,7 +2016,7 @@ TEST_CASE(equal_bool_test) auto input2 = mm->add_parameter("x2", sb); auto cin1 = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), input1); auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2); mm->add_return({ret}); @@ -2726,7 +2726,7 @@ TEST_CASE(greater_test) auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2); auto ret = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), gr); mm->add_return({ret}); @@ -2745,7 +2745,7 @@ TEST_CASE(greater_bool_test) auto input2 = mm->add_parameter("x2", sb); auto cin1 = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), input1); auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2); mm->add_return({ret}); @@ -3602,7 +3602,7 @@ TEST_CASE(less_test) auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2); auto ret = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), le); mm->add_return({ret}); @@ -3621,7 +3621,7 @@ TEST_CASE(less_bool_test) auto input2 = mm->add_parameter("x2", sb); auto cin1 = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), input1); auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2); mm->add_return({ret}); @@ -5463,7 +5463,7 @@ TEST_CASE(reducel1_dyn_test) // a shape with 4 dynamic dimensions auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, - {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}}); + {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}}); auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0); auto sum_ins = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins); @@ -5483,7 +5483,7 @@ TEST_CASE(reducel1_dyn_test) // No axes given in the onnx file. Parser should default to all axes. 
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, - {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}}); + {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}}); auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0); auto sum_ins = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins); @@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test) EXPECT(p == prog); } +TEST_CASE(reshape_variable_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); + auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}}); + auto alloc = mm->add_instruction( + migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1); + mm->add_instruction(migraphx::make_op("reshape"), p0, alloc); + + auto prog = optimize_onnx("reshape_variable_input_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(reshape_variable_input_dyn_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto p0 = mm->add_parameter( + "0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}}}); + auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}}); + auto alloc = mm->add_instruction( + migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1); + auto reshape = mm->add_instruction(migraphx::make_op("reshape"), p0, alloc); + mm->add_return({reshape}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {1, 4}; + auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options); + EXPECT(p == prog); +} + TEST_CASE(resize_downsample_c_test) { migraphx::program p; @@ -7169,7 +7201,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test) std::vector unsqueeze_axes{0, 1, 3, 5}; auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, - {{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}}); + {{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}}); auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0); auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1); @@ -7249,7 +7281,7 @@ TEST_CASE(sum_int_test) auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}}); auto cin0 = mm->add_instruction( migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}), + {{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}), input0); auto cin1 = mm->add_instruction( migraphx::make_op("convert", diff --git a/test/onnx/reshape_variable_input_dyn_test.onnx b/test/onnx/reshape_variable_input_dyn_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..63155647d9ea3b8bf8419493a631ea36521940a2 GIT binary patch literal 153 zcmd z4}z#Bs4Yqqs7XkGi;sg*h>MGXi;05?h?%4KK?ZSgad5B;v48}V1i2WEgm{1=j6!T+ Kb!<*d0{j3)U?PYB literal 0 HcmV?d00001 diff --git a/test/onnx/reshape_variable_input_test.onnx b/test/onnx/reshape_variable_input_test.onnx new file mode 100644 index 00000000000..c0dc123a7eb --- /dev/null +++ b/test/onnx/reshape_variable_input_test.onnx @@ -0,0 +1,17 @@ +reshape_variable_input_test:p + +0 +12"Reshapereshape_variable_input_testZ +0 + + + +Z +1 + + +b +2 +  + +B \ No newline at end of file diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index d03f0141ff3..862b354748b 100644 --- a/test/op_shape_test.cpp +++ 
b/test/op_shape_test.cpp
@@ -2684,7 +2684,7 @@ TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
-TEST_CASE(reshape_dyn_shape)
+TEST_CASE(reshape_dyn_1in)
 {
     migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
     for(auto&& new_shape : std::vector<std::vector<int64_t>>{
@@ -2708,6 +2708,27 @@ TEST_CASE(reshape_dyn_shape)
     }
 }
 
+TEST_CASE(reshape_dyn_2in_0)
+{
+    migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
+    migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {8, 8}, {3, 3}, {1, 1}}};
+    expect_shape(output, migraphx::make_op("reshape"), input, output);
+}
+
+TEST_CASE(reshape_dyn_2in_1)
+{
+    migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
+    migraphx::shape output{migraphx::shape::float_type, {{12, 12}, {2, 2}, {1, 1}, {1, 4}}};
+    expect_shape(output, migraphx::make_op("reshape"), input, output);
+}
+
+TEST_CASE(reshape_dyn_2in_2)
+{
+    migraphx::shape input{migraphx::shape::float_type, {2, 24, 1, 1}};
+    migraphx::shape output{migraphx::shape::float_type, {{1, 2}, {6, 12}, {1, 1}, {4, 4}}};
+    expect_shape(output, migraphx::make_op("reshape"), input, output);
+}
+
 TEST_CASE(reshape_multiple_non_fixed_error)
 {
     migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
diff --git a/test/ref/reshape.cpp b/test/ref/reshape.cpp
index b59a6cde5c4..33f8d87e760 100644
--- a/test/ref/reshape.cpp
+++ b/test/ref/reshape.cpp
@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
     EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
 }
 
-TEST_CASE(reshape_dyn_test)
+TEST_CASE(reshape_dyn_1in_test)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
@@ -173,3 +173,79 @@ TEST_CASE(reshape_dyn_test)
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
 }
+
+TEST_CASE(reshape_2in_test0)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s_in{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
+    migraphx::shape s_out{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
+    auto input         = mm->add_parameter("X", s_in);
+    auto output_buffer = mm->add_parameter("Y", s_out);
+    mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
+    p.compile(migraphx::make_target("ref"));
+
+    std::vector<float> gold(48);
+    std::iota(gold.begin(), gold.end(), -3.);
+    std::vector<float> buffer(48);
+    std::iota(buffer.begin(), buffer.end(), 0.);
+    migraphx::parameter_map params;
+    migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
+    migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
+    params["X"] = migraphx::argument(input_fixed_shape, gold.data());
+    params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
+    auto result = p.eval(params).back();
+    EXPECT(result.get_shape() == output_fixed_shape);
+    std::vector<float> results_vector{};
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(reshape_2in_test1)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
+    migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
+    auto input         = mm->add_parameter("X", s_in);
+    auto output_buffer 
= mm->add_parameter("Y", s_out); + mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer); + p.compile(migraphx::make_target("ref")); + + std::vector gold(48); + std::iota(gold.begin(), gold.end(), -3.); + std::vector buffer(48); + std::iota(buffer.begin(), buffer.end(), 0.); + migraphx::parameter_map params; + migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}}; + params["X"] = migraphx::argument(s_in, gold.data()); + params["Y"] = migraphx::argument(output_fixed_shape, buffer.data()); + auto result = p.eval(params).back(); + EXPECT(result.get_shape() == output_fixed_shape); + std::vector results_vector{}; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + +TEST_CASE(reshape_2in_elements_runtime_error) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}}; + migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}}; + auto input = mm->add_parameter("X", s_in); + auto output_buffer = mm->add_parameter("Y", s_out); + mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer); + p.compile(migraphx::make_target("ref")); + + std::vector gold(48); + std::iota(gold.begin(), gold.end(), -3.); + std::vector buffer(48); + std::iota(buffer.begin(), buffer.end(), 0.); + migraphx::parameter_map params; + // elements do not match up + migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 2, 1}}; + params["X"] = migraphx::argument(s_in, gold.data()); + params["Y"] = migraphx::argument(output_fixed_shape, buffer.data()); + EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); })); +} From 52c74f0e3fbb414bc372db9ee73564ad449376aa Mon Sep 17 00:00:00 2001 From: Attila Dusnoki <126579622+attila-dusnoki-htec@users.noreply.github.com> Date: Tue, 17 Oct 2023 20:13:24 +0200 Subject: [PATCH 2/6] Add GroupNorm and LayerNorm onnx parsing (#2242) --- src/onnx/parse_groupnorm.cpp | 130 ++++++++ src/onnx/parse_layernorm.cpp | 131 ++++++++ test/onnx/gen_onnx.py | 217 ++++++++++++++ test/onnx/group_norm_3d_half_test.onnx | 30 ++ test/onnx/group_norm_3d_test.onnx | 25 ++ test/onnx/group_norm_4d_half_test.onnx | 32 ++ test/onnx/group_norm_4d_test.onnx | 27 ++ test/onnx/group_norm_5d_half_test.onnx | 34 +++ test/onnx/group_norm_5d_test.onnx | 29 ++ .../group_norm_invalid_bias_shape_test.onnx | 27 ++ ...p_norm_invalid_input_count_error_test.onnx | 22 ++ ...p_norm_invalid_input_shape_error_test.onnx | 23 ++ ...up_norm_invalid_num_groups_error_test.onnx | 27 ++ .../group_norm_invalid_scale_shape_test.onnx | 27 ++ ...oup_norm_missing_attribute_error_test.onnx | 21 ++ test/onnx/group_norm_small_eps_half_test.onnx | 30 ++ .../layer_norm_2d_axis_minus_one_test.onnx | 22 ++ test/onnx/layer_norm_2d_axis_one_test.onnx | 22 ++ test/onnx/layer_norm_2d_axis_zero_test.onnx | Bin 0 -> 214 bytes test/onnx/layer_norm_3d_half_test.onnx | 28 ++ test/onnx/layer_norm_3d_test.onnx | 24 ++ test/onnx/layer_norm_4d_half_test.onnx | 30 ++ test/onnx/layer_norm_4d_test.onnx | 26 ++ .../layer_norm_invalid_axis_error_test.onnx | Bin 0 -> 219 bytes ...r_norm_invalid_input_count_error_test.onnx | 11 + ...er_norm_invalid_minus_axis_error_test.onnx | 26 ++ .../layer_norm_invalid_shape_error_test.onnx | Bin 0 -> 212 bytes test/onnx/layer_norm_small_eps_half_test.onnx | 20 ++ test/onnx/layer_norm_without_bias_test.onnx | 16 + test/onnx/onnx_test.cpp | 282 
++++++++++++++++++ test/onnx/verify_onnx.cpp | 99 ++++++ 31 files changed, 1438 insertions(+) create mode 100644 src/onnx/parse_groupnorm.cpp create mode 100644 src/onnx/parse_layernorm.cpp create mode 100644 test/onnx/group_norm_3d_half_test.onnx create mode 100644 test/onnx/group_norm_3d_test.onnx create mode 100644 test/onnx/group_norm_4d_half_test.onnx create mode 100644 test/onnx/group_norm_4d_test.onnx create mode 100644 test/onnx/group_norm_5d_half_test.onnx create mode 100644 test/onnx/group_norm_5d_test.onnx create mode 100644 test/onnx/group_norm_invalid_bias_shape_test.onnx create mode 100644 test/onnx/group_norm_invalid_input_count_error_test.onnx create mode 100644 test/onnx/group_norm_invalid_input_shape_error_test.onnx create mode 100644 test/onnx/group_norm_invalid_num_groups_error_test.onnx create mode 100644 test/onnx/group_norm_invalid_scale_shape_test.onnx create mode 100644 test/onnx/group_norm_missing_attribute_error_test.onnx create mode 100644 test/onnx/group_norm_small_eps_half_test.onnx create mode 100644 test/onnx/layer_norm_2d_axis_minus_one_test.onnx create mode 100644 test/onnx/layer_norm_2d_axis_one_test.onnx create mode 100644 test/onnx/layer_norm_2d_axis_zero_test.onnx create mode 100644 test/onnx/layer_norm_3d_half_test.onnx create mode 100644 test/onnx/layer_norm_3d_test.onnx create mode 100644 test/onnx/layer_norm_4d_half_test.onnx create mode 100644 test/onnx/layer_norm_4d_test.onnx create mode 100644 test/onnx/layer_norm_invalid_axis_error_test.onnx create mode 100644 test/onnx/layer_norm_invalid_input_count_error_test.onnx create mode 100644 test/onnx/layer_norm_invalid_minus_axis_error_test.onnx create mode 100644 test/onnx/layer_norm_invalid_shape_error_test.onnx create mode 100644 test/onnx/layer_norm_small_eps_half_test.onnx create mode 100644 test/onnx/layer_norm_without_bias_test.onnx diff --git a/src/onnx/parse_groupnorm.cpp b/src/onnx/parse_groupnorm.cpp new file mode 100644 index 00000000000..657e36ea30d --- /dev/null +++ b/src/onnx/parse_groupnorm.cpp @@ -0,0 +1,130 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_groupnorm : op_parser<parse_groupnorm>
+{
+    std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& parser,
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
+    {
+        float epsilon = 1e-5f;
+        if(contains(info.attributes, "epsilon"))
+        {
+            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
+        }
+        size_t num_groups;
+        if(contains(info.attributes, "num_groups"))
+        {
+            num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
+        }
+        else
+        {
+            MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
+        }
+
+        if(args.size() != 3)
+        {
+            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
+        }
+
+        auto x     = args.at(0);
+        auto scale = args.at(1);
+        auto bias  = args.at(2);
+
+        auto x_shape = x->get_shape();
+        auto x_dtype = x_shape.type();
+        auto x_dims  = x_shape.lens();
+
+        if(x_shape.ndim() <= 2)
+        {
+            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input shape");
+        }
+
+        auto c = x_shape.lens().at(1);
+        if(c % num_groups != 0)
+        {
+            MIGRAPHX_THROW(
+                "PARSE_GROUPNORM: num_groups should be a divisor of the number of channels");
+        }
+        auto group_size = c / num_groups;
+        if(scale->get_shape().ndim() != 1 or scale->get_shape().lens().at(0) != num_groups)
+        {
+            MIGRAPHX_THROW("PARSE_GROUPNORM: scale tensor shape should be num_groups");
+        }
+        if(bias->get_shape().ndim() != 1 or bias->get_shape().lens().at(0) != num_groups)
+        {
+            MIGRAPHX_THROW("PARSE_GROUPNORM: bias tensor shape should be num_groups");
+        }
+
+        // Original shape: N x C x D1 x ... x Dn
+        // New shape: N x num_groups x C // num_groups x D1 x ... x Dn
+
+        std::vector<size_t> dims = {x_dims.at(0), num_groups, group_size};
+        std::copy(x_dims.begin() + 2, x_dims.end(), std::back_inserter(dims));
+        auto x_reshaped = info.add_instruction(make_op("reshape", {{"dims", dims}}), x);
+
+        // Axes for D1 x ... x Dn
+        std::vector<int64_t> axes(dims.size() - 2);
+        std::iota(axes.begin(), axes.end(), 2);
+
+        // y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
+        // mean = reduce_mean({D1, D2, ... Dk}, x)
+        // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
+
+        auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_reshaped);
+        auto x_sub_mean    = info.add_common_op("sub", x_reshaped, mean);
+        auto x_sqdiff_mean = info.add_common_op("sqdiff", x_reshaped, mean);
+        auto variance =
+            info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
+        epsilon =
+            (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 
1e-7 : epsilon; + auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}}); + auto var_eps = info.add_common_op("add", variance, eps); + auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps); + auto result = info.add_common_op("mul", x_sub_mean, rsqrt); + auto scale_bcast = + info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); + auto bias_bcast = + info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); + auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast); + auto y = info.add_instruction(make_op("add"), scaled, bias_bcast); + auto y_reshaped = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y); + return y_reshaped; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_layernorm.cpp b/src/onnx/parse_layernorm.cpp new file mode 100644 index 00000000000..5e6e4da7054 --- /dev/null +++ b/src/onnx/parse_layernorm.cpp @@ -0,0 +1,131 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_layernorm : op_parser<parse_layernorm>
+{
+    std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }
+
+    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
+                                       const onnx_parser& parser,
+                                       const onnx_parser::node_info& info,
+                                       std::vector<instruction_ref> args) const
+    {
+        int64_t axis = -1;
+        if(contains(info.attributes, "axis"))
+        {
+            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
+        }
+        float epsilon = 1e-5f;
+        if(contains(info.attributes, "epsilon"))
+        {
+            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
+        }
+        if(contains(info.attributes, "stash_type"))
+        {
+            std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
+        }
+
+        if(args.size() < 2 or args.size() > 3)
+        {
+            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
+        }
+
+        auto x         = args.at(0);
+        auto scale     = args.at(1);
+        bool skip_bias = args.size() == 2;
+        instruction_ref bias;
+        if(not skip_bias)
+        {
+            bias = args.at(2);
+        }
+
+        auto x_shape   = x->get_shape();
+        auto x_dtype   = x_shape.type();
+        int64_t x_rank = x_shape.ndim();
+
+        if(x_rank < 2)
+        {
+            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input shape");
+        }
+
+        // If rank(X) is r, axis' allowed range is [-r, r)
+        if(axis < -x_rank or axis >= x_rank)
+        {
+            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid axis");
+        }
+
+        // y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
+        // mean = reduce_mean({D1, D2, ... Dk}, x)
+        // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
+
+        // axis can be negative
+        axis = axis < 0 ? axis + x_rank : axis;
+
+        auto kdims = x_rank - axis;
+        std::vector<int64_t> axes(kdims);
+        std::iota(axes.begin(), axes.end(), axis);
+        auto skipped_axes = x_rank - kdims;
+
+        auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
+        auto x_sub_mean    = info.add_common_op("sub", x, mean);
+        auto x_sqdiff_mean = info.add_common_op("sqdiff", x, mean);
+        auto variance =
+            info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
+        epsilon =
+            (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
+        auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
+        auto var_eps = info.add_common_op("add", variance, eps);
+        auto rsqrt   = info.add_instruction(make_op("rsqrt"), var_eps);
+        auto result  = info.add_common_op("mul", x_sub_mean, rsqrt);
+
+        instruction_ref scale_bcast = scale;
+        instruction_ref bias_bcast  = bias;
+        if(skipped_axes > 0)
+        {
+            auto x_dims = x_shape.lens();
+            scale_bcast = info.add_instruction(
+                make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), scale);
+            if(not skip_bias)
+            {
+                bias_bcast = info.add_instruction(
+                    make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), bias);
+            }
+        }
+        auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
+        auto y      = skip_bias ? 
scaled : info.add_instruction(make_op("add"), scaled, bias_bcast); + return {y, mean, rsqrt}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 581a264f35f..577eb99f409 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -2722,6 +2722,119 @@ def group_conv_test(): return ([node], [x, y], [z]) +def group_norm_test(x_dims, + scale_dims, + bias_dims, + y_dims, + num_groups, + eps_value=1e-5, + dtype=TensorProto.FLOAT): + x = helper.make_tensor_value_info('x', dtype, x_dims) + scale = helper.make_tensor_value_info('scale', dtype, scale_dims) + bias = helper.make_tensor_value_info('bias', dtype, bias_dims) + y = helper.make_tensor_value_info('y', dtype, y_dims) + + node = onnx.helper.make_node('GroupNormalization', + inputs=['x', 'scale', 'bias'], + outputs=['y'], + num_groups=num_groups, + epsilon=eps_value) + + return ([node], [x, scale, bias], [y]) + + +@onnx_test() +def group_norm_3d_test(): + return group_norm_test([1, 4, 2], [2], [2], [1, 4, 2], 2) + + +@onnx_test() +def group_norm_3d_half_test(): + return group_norm_test([1, 4, 2], [2], [2], [1, 4, 2], + 2, + dtype=TensorProto.FLOAT16) + + +@onnx_test() +def group_norm_4d_test(): + return group_norm_test([1, 4, 3, 3], [2], [2], [1, 4, 3, 3], 2) + + +@onnx_test() +def group_norm_4d_half_test(): + return group_norm_test([1, 4, 3, 3], [2], [2], [1, 4, 3, 3], + 2, + dtype=TensorProto.FLOAT16) + + +@onnx_test() +def group_norm_5d_test(): + return group_norm_test([3, 3, 3, 3, 3], [1], [1], [3, 3, 3, 3, 3], 1) + + +@onnx_test() +def group_norm_5d_half_test(): + return group_norm_test([3, 3, 3, 3, 3], [1], [1], [3, 3, 3, 3, 3], + 1, + dtype=TensorProto.FLOAT16) + + +@onnx_test() +def group_norm_small_eps_half_test(): + return group_norm_test([1, 4, 2], [2], [2], [1, 4, 2], + 2, + eps_value=1e-12, + dtype=TensorProto.FLOAT16) + + +@onnx_test() +def group_norm_invalid_num_groups_error_test(): + return group_norm_test([1, 4, 3, 3], [2], [2], [1, 4, 3, 3], 3) + + +@onnx_test() +def group_norm_missing_attribute_error_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [2]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4]) + + node = onnx.helper.make_node('GroupNormalization', + inputs=['x', 'scale', 'bias'], + outputs=['y']) + + return ([node], [x, scale, bias], [y]) + + +@onnx_test() +def group_norm_invalid_input_count_error_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 3, 3]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 3, 3]) + + node = onnx.helper.make_node('GroupNormalization', + inputs=['x', 'scale'], + outputs=['y'], + num_groups=2) + + return ([node], [x, scale], [y]) + + +@onnx_test() +def group_norm_invalid_input_shape_error_test(): + return group_norm_test([1, 4], [2], [2], [1, 4], 2) + + +@onnx_test() +def group_norm_invalid_scale_shape_test(): + return group_norm_test([1, 4, 3, 3], [1], [2], [1, 4, 3, 3], 2) + + +@onnx_test() +def group_norm_invalid_bias_shape_test(): + return group_norm_test([1, 4, 3, 3], [2], [3], [1, 4, 3, 3], 2) + + @onnx_test() def hardsigmoid_default_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 4, 5]) @@ -3804,6 +3917,110 @@ def layernorm_test(): bias_add], [x, scale, 
bias], [y], [pow_tensor, epsilon_tensor]) +def make_layer_norm(shape, axis, dtype=TensorProto.FLOAT): + norm_axis = axis + len(shape) if axis < 0 else axis + x = helper.make_tensor_value_info('x', dtype, shape) + scale = helper.make_tensor_value_info('scale', dtype, shape[norm_axis:]) + bias = helper.make_tensor_value_info('bias', dtype, shape[norm_axis:]) + y = helper.make_tensor_value_info('y', dtype, shape) + + node = onnx.helper.make_node('LayerNormalization', + inputs=['x', 'scale', 'bias'], + outputs=['y'], + axis=axis) + + return ([node], [x, scale, bias], [y]) + + +@onnx_test() +def layer_norm_invalid_shape_error_test(): + return make_layer_norm([3], 0) + + +@onnx_test() +def layer_norm_2d_axis_zero_test(): + return make_layer_norm([3, 4], 0) + + +@onnx_test() +def layer_norm_2d_axis_one_test(): + return make_layer_norm([3, 4], 1) + + +@onnx_test() +def layer_norm_2d_axis_minus_one_test(): + return make_layer_norm([3, 4], -1) + + +@onnx_test() +def layer_norm_3d_test(): + return make_layer_norm([1, 4, 2], -1) + + +@onnx_test() +def layer_norm_3d_half_test(): + return make_layer_norm([1, 4, 2], -1, TensorProto.FLOAT16) + + +@onnx_test() +def layer_norm_4d_test(): + return make_layer_norm([3, 3, 3, 3], -1) + + +@onnx_test() +def layer_norm_4d_half_test(): + return make_layer_norm([3, 3, 3, 3], -1, TensorProto.FLOAT16) + + +@onnx_test() +def layer_norm_invalid_axis_error_test(): + return make_layer_norm([1, 4, 2], 1000) + + +@onnx_test() +def layer_norm_invalid_minus_axis_error_test(): + return make_layer_norm([1, 4, 2], -1000) + + +@onnx_test() +def layer_norm_invalid_input_count_error_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('LayerNormalization', + inputs=['x'], + outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test() +def layer_norm_without_bias_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('LayerNormalization', + inputs=['x', 'scale'], + outputs=['y']) + + return ([node], [x, scale], [y]) + + +@onnx_test() +def layer_norm_small_eps_half_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [1, 2]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 2]) + + node = onnx.helper.make_node('LayerNormalization', + inputs=['x', 'scale'], + outputs=['y'], + epsilon=1e-12) + + return ([node], [x, scale], [y]) + + @onnx_test() def leaky_relu_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) diff --git a/test/onnx/group_norm_3d_half_test.onnx b/test/onnx/group_norm_3d_half_test.onnx new file mode 100644 index 00000000000..03d89cd0f69 --- /dev/null +++ b/test/onnx/group_norm_3d_half_test.onnx @@ -0,0 +1,30 @@ +group_norm_3d_half_test:à +M +x +scale +biasy"GroupNormalization* +epsilon¬Å'7 * + +num_groups group_norm_3d_half_testZ +x + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_3d_test.onnx b/test/onnx/group_norm_3d_test.onnx new file mode 100644 index 00000000000..e264ba459cc --- /dev/null +++ b/test/onnx/group_norm_3d_test.onnx @@ -0,0 +1,25 @@ + group_norm_3d_test:« +: +x +scale +biasy"GroupNormalization* + +num_groups group_norm_3d_testZ +x + + + +Z 
+scale + + +Z +bias + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_4d_half_test.onnx b/test/onnx/group_norm_4d_half_test.onnx new file mode 100644 index 00000000000..48302d36b9a --- /dev/null +++ b/test/onnx/group_norm_4d_half_test.onnx @@ -0,0 +1,32 @@ +group_norm_4d_half_test:Ë +M +x +scale +biasy"GroupNormalization* +epsilon¬Å'7 * + +num_groups group_norm_4d_half_testZ +x + + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_4d_test.onnx b/test/onnx/group_norm_4d_test.onnx new file mode 100644 index 00000000000..03e86d2d2bc --- /dev/null +++ b/test/onnx/group_norm_4d_test.onnx @@ -0,0 +1,27 @@ + group_norm_4d_test:³ +: +x +scale +biasy"GroupNormalization* + +num_groups group_norm_4d_testZ +x + + + + +Z +scale + + +Z +bias + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_5d_half_test.onnx b/test/onnx/group_norm_5d_half_test.onnx new file mode 100644 index 00000000000..af8ecdb4800 --- /dev/null +++ b/test/onnx/group_norm_5d_half_test.onnx @@ -0,0 +1,34 @@ +group_norm_5d_half_test:Ó +M +x +scale +biasy"GroupNormalization* +epsilon¬Å'7 * + +num_groups group_norm_5d_half_testZ +x + + + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_5d_test.onnx b/test/onnx/group_norm_5d_test.onnx new file mode 100644 index 00000000000..cdcf82369cc --- /dev/null +++ b/test/onnx/group_norm_5d_test.onnx @@ -0,0 +1,29 @@ + group_norm_5d_test:» +: +x +scale +biasy"GroupNormalization* + +num_groups group_norm_5d_testZ +x + + + + + +Z +scale + + +Z +bias + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_invalid_bias_shape_test.onnx b/test/onnx/group_norm_invalid_bias_shape_test.onnx new file mode 100644 index 00000000000..3b684c2e20b --- /dev/null +++ b/test/onnx/group_norm_invalid_bias_shape_test.onnx @@ -0,0 +1,27 @@ + "group_norm_invalid_bias_shape_test:à +: +x +scale +biasy"GroupNormalization* + +num_groups "group_norm_invalid_bias_shape_testZ +x + + + + +Z +scale + + +Z +bias + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_invalid_input_count_error_test.onnx b/test/onnx/group_norm_invalid_input_count_error_test.onnx new file mode 100644 index 00000000000..977682fc10e --- /dev/null +++ b/test/onnx/group_norm_invalid_input_count_error_test.onnx @@ -0,0 +1,22 @@ + )group_norm_invalid_input_count_error_test:° +4 +x +scaley"GroupNormalization* + +num_groups )group_norm_invalid_input_count_error_testZ +x + + + + +Z +scale + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_invalid_input_shape_error_test.onnx b/test/onnx/group_norm_invalid_input_shape_error_test.onnx new file mode 100644 index 00000000000..904e5e4e847 --- /dev/null +++ b/test/onnx/group_norm_invalid_input_shape_error_test.onnx @@ -0,0 +1,23 @@ + )group_norm_invalid_input_shape_error_test:º +: +x +scale +biasy"GroupNormalization* + +num_groups )group_norm_invalid_input_shape_error_testZ +x +  + +Z +scale + + +Z +bias + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/group_norm_invalid_num_groups_error_test.onnx b/test/onnx/group_norm_invalid_num_groups_error_test.onnx new file mode 100644 index 00000000000..c8d581d170a --- /dev/null +++ b/test/onnx/group_norm_invalid_num_groups_error_test.onnx @@ -0,0 +1,27 @@ + (group_norm_invalid_num_groups_error_test:É +: +x +scale +biasy"GroupNormalization* + +num_groups 
(group_norm_invalid_num_groups_error_testZ +x + + + + +Z +scale + + +Z +bias + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_invalid_scale_shape_test.onnx b/test/onnx/group_norm_invalid_scale_shape_test.onnx new file mode 100644 index 00000000000..c8587f7a561 --- /dev/null +++ b/test/onnx/group_norm_invalid_scale_shape_test.onnx @@ -0,0 +1,27 @@ + #group_norm_invalid_scale_shape_test:Ä +: +x +scale +biasy"GroupNormalization* + +num_groups #group_norm_invalid_scale_shape_testZ +x + + + + +Z +scale + + +Z +bias + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/group_norm_missing_attribute_error_test.onnx b/test/onnx/group_norm_missing_attribute_error_test.onnx new file mode 100644 index 00000000000..59849649b53 --- /dev/null +++ b/test/onnx/group_norm_missing_attribute_error_test.onnx @@ -0,0 +1,21 @@ + 'group_norm_missing_attribute_error_test:¥ +' +x +scale +biasy"GroupNormalization'group_norm_missing_attribute_error_testZ +x +  + +Z +scale + + +Z +bias + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/group_norm_small_eps_half_test.onnx b/test/onnx/group_norm_small_eps_half_test.onnx new file mode 100644 index 00000000000..0b5e538598e --- /dev/null +++ b/test/onnx/group_norm_small_eps_half_test.onnx @@ -0,0 +1,30 @@ +group_norm_small_eps_half_test:Ê +M +x +scale +biasy"GroupNormalization* +epsilon̼Œ+ * + +num_groups group_norm_small_eps_half_testZ +x + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_2d_axis_minus_one_test.onnx b/test/onnx/layer_norm_2d_axis_minus_one_test.onnx new file mode 100644 index 00000000000..7bb5eaaeec9 --- /dev/null +++ b/test/onnx/layer_norm_2d_axis_minus_one_test.onnx @@ -0,0 +1,22 @@ + !layer_norm_2d_axis_minus_one_test:µ += +x +scale +biasy"LayerNormalization* +axisÿÿÿÿÿÿÿÿÿ !layer_norm_2d_axis_minus_one_testZ +x +  + +Z +scale + + +Z +bias + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_2d_axis_one_test.onnx b/test/onnx/layer_norm_2d_axis_one_test.onnx new file mode 100644 index 00000000000..fa1444c0a06 --- /dev/null +++ b/test/onnx/layer_norm_2d_axis_one_test.onnx @@ -0,0 +1,22 @@ + layer_norm_2d_axis_one_test:¦ +4 +x +scale +biasy"LayerNormalization* +axis layer_norm_2d_axis_one_testZ +x +  + +Z +scale + + +Z +bias + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_2d_axis_zero_test.onnx b/test/onnx/layer_norm_2d_axis_zero_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..74e2ea452cdc1062d0d1ceb3632f1aa6efc02ea5 GIT binary patch literal 214 zcmd2h^09lUbEml9`{U#m&V6(k;QTfRPEmu~EW6GllrLcsLk^IJlTN en1PrjN*rnnhJ+Z@UNj9!KvOEw6+1Br2m=760Wik^ literal 0 HcmV?d00001 diff --git a/test/onnx/layer_norm_3d_half_test.onnx b/test/onnx/layer_norm_3d_half_test.onnx new file mode 100644 index 00000000000..1d65c0e5b33 --- /dev/null +++ b/test/onnx/layer_norm_3d_half_test.onnx @@ -0,0 +1,28 @@ +layer_norm_3d_half_test:³ += +x +scale +biasy"LayerNormalization* +axisÿÿÿÿÿÿÿÿÿ layer_norm_3d_half_testZ +x + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_3d_test.onnx b/test/onnx/layer_norm_3d_test.onnx new file mode 100644 index 00000000000..a8b32801ca9 --- /dev/null +++ b/test/onnx/layer_norm_3d_test.onnx @@ -0,0 +1,24 @@ + layer_norm_3d_test:® += +x +scale +biasy"LayerNormalization* +axisÿÿÿÿÿÿÿÿÿ layer_norm_3d_testZ +x + + + +Z +scale + + +Z +bias + + +b +y + + + +B \ No newline at 
end of file diff --git a/test/onnx/layer_norm_4d_half_test.onnx b/test/onnx/layer_norm_4d_half_test.onnx new file mode 100644 index 00000000000..0b4f32087e0 --- /dev/null +++ b/test/onnx/layer_norm_4d_half_test.onnx @@ -0,0 +1,30 @@ +layer_norm_4d_half_test:» += +x +scale +biasy"LayerNormalization* +axisÿÿÿÿÿÿÿÿÿ layer_norm_4d_half_testZ +x + + + + + +Z +scale + + + +Z +bias + + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_4d_test.onnx b/test/onnx/layer_norm_4d_test.onnx new file mode 100644 index 00000000000..214aae27228 --- /dev/null +++ b/test/onnx/layer_norm_4d_test.onnx @@ -0,0 +1,26 @@ + layer_norm_4d_test:¶ += +x +scale +biasy"LayerNormalization* +axisÿÿÿÿÿÿÿÿÿ layer_norm_4d_testZ +x + + + + +Z +scale + + +Z +bias + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_invalid_axis_error_test.onnx b/test/onnx/layer_norm_invalid_axis_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..eeaaa4af4de4c877f85ea8da3c12ecc767ec807e GIT binary patch literal 219 zcmdbPEkY3#f?A8DFbSx$5?PTGlYlS)QW!Q` literal 0 HcmV?d00001 diff --git a/test/onnx/layer_norm_small_eps_half_test.onnx b/test/onnx/layer_norm_small_eps_half_test.onnx new file mode 100644 index 00000000000..08411c7cacc --- /dev/null +++ b/test/onnx/layer_norm_small_eps_half_test.onnx @@ -0,0 +1,20 @@ +layer_norm_small_eps_half_test:• +4 +x +scaley"LayerNormalization* +epsilon̼Œ+ layer_norm_small_eps_half_testZ +x +  + + +Z +scale + + + +b +y +  + + +B \ No newline at end of file diff --git a/test/onnx/layer_norm_without_bias_test.onnx b/test/onnx/layer_norm_without_bias_test.onnx new file mode 100644 index 00000000000..4f584da4382 --- /dev/null +++ b/test/onnx/layer_norm_without_bias_test.onnx @@ -0,0 +1,16 @@ + layer_norm_without_bias_test:€ +! 
+x +scaley"LayerNormalizationlayer_norm_without_bias_testZ +x +  + +Z +scale + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 368b8918e2f..415a612d0e0 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -2786,6 +2786,145 @@ TEST_CASE(group_conv_test) EXPECT(p == prog); } +migraphx::program make_group_norm(const std::vector& input_dims, + const std::vector& scale_dims, + const std::vector& bias_dims, + const std::vector& reshape_dims, + const std::vector& reduce_axes, + const float eps_value = 1e-5f, + const migraphx::shape::type_t dtype = migraphx::shape::float_type) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {dtype, input_dims}); + auto scale = mm->add_parameter("scale", {dtype, scale_dims}); + auto bias = mm->add_parameter("bias", {dtype, bias_dims}); + + auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}}); + + auto x_reshaped = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", reshape_dims}}), x); + auto mean = + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x_reshaped); + auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x_reshaped, mean}); + auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x_reshaped, mean}); + auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), + x_sqdiff_mean); + auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); + auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps}); + auto result = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt}); + auto scale_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", reshape_dims}}), scale); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", reshape_dims}}), bias); + auto scaled = mm->add_instruction(migraphx::make_op("mul"), {result, scale_bcast}); + auto y = mm->add_instruction(migraphx::make_op("add"), {scaled, bias_bcast}); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", input_dims}}), y); + + return p; +} + +TEST_CASE(group_norm_3d_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type); + auto prog = optimize_onnx("group_norm_3d_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_3d_half_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::half_type); + auto prog = optimize_onnx("group_norm_3d_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_4d_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 3, 3}, {2}, {2}, {1, 2, 2, 3, 3}, {2, 3, 4}, 1e-5f, migraphx::shape::float_type); + auto prog = optimize_onnx("group_norm_4d_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_4d_half_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 3, 3}, {2}, {2}, {1, 2, 2, 3, 3}, {2, 3, 4}, 1e-5f, migraphx::shape::half_type); + auto prog = optimize_onnx("group_norm_4d_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_5d_test) +{ + migraphx::program p = make_group_norm({3, 3, 3, 3, 3}, + {1}, + {1}, + {3, 1, 3, 3, 3, 3}, + {2, 3, 4, 5}, + 1e-5f, + migraphx::shape::float_type); + auto prog = optimize_onnx("group_norm_5d_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_5d_half_test) +{ + migraphx::program p = 
make_group_norm({3, 3, 3, 3, 3}, + {1}, + {1}, + {3, 1, 3, 3, 3, 3}, + {2, 3, 4, 5}, + 1e-5f, + migraphx::shape::half_type); + auto prog = optimize_onnx("group_norm_5d_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_small_eps_half_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-7f, migraphx::shape::half_type); + auto prog = optimize_onnx("group_norm_small_eps_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_norm_invalid_num_groups_error_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("group_norm_invalid_num_groups_error_test.onnx"); })); +} + +TEST_CASE(group_norm_missing_attribute_error_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("group_norm_missing_attribute_error_test.onnx"); })); +} + +TEST_CASE(group_norm_invalid_input_count_error_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("group_norm_invalid_input_count_error_test.onnx"); })); +} + +TEST_CASE(group_norm_invalid_input_shape_error_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("group_norm_invalid_input_shape_error_test.onnx"); })); +} + +TEST_CASE(group_norm_invalid_scale_shape_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("group_norm_invalid_scale_shape_test.onnx"); })); +} + +TEST_CASE(group_norm_invalid_bias_shape_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("group_norm_invalid_bias_shape_test.onnx"); })); +} + TEST_CASE(hardsigmoid_default_test) { migraphx::program p; @@ -3648,6 +3787,149 @@ TEST_CASE(lessorequal_test) EXPECT(p == prog); } +migraphx::program make_layer_norm(const std::vector& input_shape, + const std::vector& scale_bias_shape, + const std::vector& reduce_axes, + size_t skipped_axis, + bool skip_bias = false, + const float eps_value = 1e-5f, + const migraphx::shape::type_t dtype = migraphx::shape::float_type) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {dtype, input_shape}); + auto scale = mm->add_parameter("scale", {dtype, scale_bias_shape}); + migraphx::instruction_ref bias; + if(not skip_bias) + { + bias = mm->add_parameter("bias", {dtype, scale_bias_shape}); + } + + auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}}); + + auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x); + auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); + auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean}); + auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), + x_sqdiff_mean); + auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); + auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps}); + auto result = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt}); + migraphx::instruction_ref scale_bcast = scale; + migraphx::instruction_ref bias_bcast = bias; + if(skipped_axis > 0) + { + scale_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", skipped_axis}, {"out_lens", input_shape}}), + scale); + if(not skip_bias) + { + bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", skipped_axis}, {"out_lens", input_shape}}), + bias); + } + } + auto scaled = mm->add_instruction(migraphx::make_op("mul"), {result, scale_bcast}); + if(not skip_bias) + { + mm->add_instruction(migraphx::make_op("add"), {scaled, bias_bcast}); + } + + return p; +} + +TEST_CASE(layer_norm_2d_axis_zero_test) +{ + migraphx::program p 
+TEST_CASE(layer_norm_2d_axis_zero_test)
+{
+    migraphx::program p = make_layer_norm({3, 4}, {3, 4}, {0, 1}, 0);
+
+    auto prog = optimize_onnx("layer_norm_2d_axis_zero_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_2d_axis_one_test)
+{
+    migraphx::program p = make_layer_norm({3, 4}, {4}, {1}, 1);
+
+    auto prog = optimize_onnx("layer_norm_2d_axis_one_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_2d_axis_minus_one_test)
+{
+    migraphx::program p = make_layer_norm({3, 4}, {4}, {1}, 1);
+
+    auto prog = optimize_onnx("layer_norm_2d_axis_minus_one_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_3d_test)
+{
+    migraphx::program p = make_layer_norm({1, 4, 2}, {2}, {2}, 2);
+
+    auto prog = optimize_onnx("layer_norm_3d_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_3d_half_test)
+{
+    migraphx::program p =
+        make_layer_norm({1, 4, 2}, {2}, {2}, 2, false, 1e-5f, migraphx::shape::half_type);
+
+    auto prog = optimize_onnx("layer_norm_3d_half_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_4d_test)
+{
+    migraphx::program p = make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3);
+
+    auto prog = optimize_onnx("layer_norm_4d_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_4d_half_test)
+{
+    migraphx::program p =
+        make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, 1e-5f, migraphx::shape::half_type);
+
+    auto prog = optimize_onnx("layer_norm_4d_half_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_invalid_axis_error_test)
+{
+    EXPECT(test::throws([&] { migraphx::parse_onnx("layer_norm_invalid_axis_error_test.onnx"); }));
+}
+
+TEST_CASE(layer_norm_invalid_minus_axis_error_test)
+{
+    EXPECT(test::throws(
+        [&] { migraphx::parse_onnx("layer_norm_invalid_minus_axis_error_test.onnx"); }));
+}
+
+TEST_CASE(layer_norm_invalid_input_count_error_test)
+{
+    EXPECT(test::throws(
+        [&] { migraphx::parse_onnx("layer_norm_invalid_input_count_error_test.onnx"); }));
+}
+
+TEST_CASE(layer_norm_without_bias_test)
+{
+    migraphx::program p = make_layer_norm({1, 2}, {2}, {1}, 1, true);
+
+    auto prog = optimize_onnx("layer_norm_without_bias_test.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(layer_norm_small_eps_half_test)
+{
+    migraphx::program p =
+        make_layer_norm({1, 2}, {2}, {1}, 1, true, 1e-7, migraphx::shape::half_type);
+
+    auto prog = optimize_onnx("layer_norm_small_eps_half_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(log_test)
 {
     migraphx::program p;
diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp
index f74efe08392..5e6a5094e45 100644
--- a/test/onnx/verify_onnx.cpp
+++ b/test/onnx/verify_onnx.cpp
@@ -538,6 +538,70 @@ TEST_CASE(gemm_half_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
+template <class T>
+std::vector<T> norm_test(const std::vector<std::size_t>& x_dims,
+                         std::vector<T>& scale,
+                         std::vector<T>& bias,
+                         const std::string& onnx_file)
+{
+    migraphx::program p = migraphx::parse_onnx(onnx_file);
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape s_x{migraphx::shape::get_type<T>{}, x_dims};
+    migraphx::shape s_s{migraphx::shape::get_type<T>{}, {scale.size()}};
+    migraphx::shape s_b{migraphx::shape::get_type<T>{}, {scale.size()}};
+
+    std::vector<T> x(s_x.elements());
+    std::iota(std::begin(x), std::end(x), 1);
+
+    migraphx::parameter_map pp;
+    pp["x"]     = migraphx::argument(s_x, x.data());
+    pp["scale"] = migraphx::argument(s_s, scale.data());
+    pp["bias"]  = migraphx::argument(s_b, bias.data());
+
+    auto result = p.eval(pp).back();
+
+    std::vector<T> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    return result_vector;
+}
+
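(Editor's note, not part of the patch: the group_norm gold vectors in the tests that follow come from the same recipe, except that normalization runs over each whole group of channels and — per the ONNX GroupNormalization semantics assumed here — scale and bias hold one value per group rather than per channel. Illustrative Python:)

import math

# x = iota(1..8) viewed as {1, 4, 2} with num_groups = 2: each group spans two
# channels, i.e. four values; scale/bias are indexed by group.
x = [float(v) for v in range(1, 9)]
scale, bias = [1.2, 0.8], [0.5, 0.2]
eps, group_size = 1e-5, 4

for g in range(len(x) // group_size):
    group = x[g * group_size:(g + 1) * group_size]
    mean = sum(group) / group_size
    var = sum((v - mean) ** 2 for v in group) / group_size
    for v in group:
        y = (v - mean) / math.sqrt(var + eps) * scale[g] + bias[g]
        print(f"{y:.8f}")  # first group: ~ -1.10996, -0.03665, 1.03665, 2.10996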
+TEST_CASE(group_norm_test) +{ + std::vector scale{1.2, 0.8}; + std::vector bias{0.5, 0.2}; + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, "group_norm_3d_test.onnx"); + std::vector gold = {-1.10996256, + -0.0366542, + 1.0366542, + 2.10996256, + -0.87330837, + -0.15776947, + 0.55776947, + 1.27330837}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(group_norm_half_test) +{ + using migraphx::half; + std::vector scale{half{1.2}, half{0.8}}; + std::vector bias{half{0.5}, half{0.2}}; + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, "group_norm_3d_half_test.onnx"); + std::vector gold = {half{-1.10996256}, + half{-0.0366542}, + half{1.0366542}, + half{2.10996256}, + half{-0.87330837}, + half{-0.15776947}, + half{0.55776947}, + half{1.27330837}}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + TEST_CASE(greaterorequal_test) { migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx"); @@ -950,6 +1014,41 @@ TEST_CASE(instance_norm_3d_test) EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); } +TEST_CASE(layer_norm_test) +{ + std::vector scale{1.2, 0.8}; + std::vector bias{0.5, 0.2}; + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, "layer_norm_3d_test.onnx"); + std::vector gold = {-0.69997597, + 0.99998398, + -0.69997597, + 0.99998398, + -0.69997597, + 0.99998398, + -0.69997597, + 0.99998398}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(layer_norm_half_test) +{ + using migraphx::half; + std::vector scale{half{1.2}, half{0.8}}; + std::vector bias{half{0.5}, half{0.2}}; + std::vector result_vector = + norm_test({1, 4, 2}, scale, bias, "layer_norm_3d_half_test.onnx"); + std::vector gold = {half{-0.69997597}, + half{0.99998398}, + half{-0.69997597}, + half{0.99998398}, + half{-0.69997597}, + half{0.99998398}, + half{-0.69997597}, + half{0.99998398}}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + TEST_CASE(lessorequal_test) { migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx"); From 94bda2432b9d5d228a7a6ff6ef92c77f7798dc89 Mon Sep 17 00:00:00 2001 From: Attila Dusnoki <126579622+attila-dusnoki-htec@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:07:49 +0200 Subject: [PATCH 3/6] Add axes (optional) input to Pad (#2178) --- src/onnx/parse_pad.cpp | 147 ++++++++++---- test/onnx/gen_onnx.py | 182 ++++++++++++++++++ test/onnx/onnx_test.cpp | 82 ++++++++ test/onnx/pad_4arg_axes_test.onnx | Bin 0 -> 305 bytes .../pad_4arg_invalid_axes_error_test.onnx | Bin 0 -> 334 bytes test/onnx/pad_4arg_neg_axes_test.onnx | Bin 0 -> 331 bytes .../pad_asym_invalid_pads_error_test.onnx | Bin 0 -> 171 bytes test/onnx/pad_asym_test.onnx | Bin 0 -> 136 bytes test/onnx/pad_reflect_with_axes_test.onnx | 18 ++ 9 files changed, 396 insertions(+), 33 deletions(-) create mode 100644 test/onnx/pad_4arg_axes_test.onnx create mode 100644 test/onnx/pad_4arg_invalid_axes_error_test.onnx create mode 100644 test/onnx/pad_4arg_neg_axes_test.onnx create mode 100644 test/onnx/pad_asym_invalid_pads_error_test.onnx create mode 100644 test/onnx/pad_asym_test.onnx create mode 100644 test/onnx/pad_reflect_with_axes_test.onnx diff --git a/src/onnx/parse_pad.cpp b/src/onnx/parse_pad.cpp index 5f425211c66..a654ca06b59 100644 --- a/src/onnx/parse_pad.cpp +++ b/src/onnx/parse_pad.cpp @@ -115,34 +115,9 @@ struct parse_pad : op_parser { std::vector operators() const { return {{"Pad"}}; } - instruction_ref parse(const op_desc& /*opd*/, - 
const onnx_parser& parser, - onnx_parser::node_info info, - std::vector args) const + std::string parse_mode(const onnx_parser::node_info& info, + const std::vector& args) const { - std::vector pads{}; - if(args.size() >= 2) - { - auto pad_arg = args.at(1)->eval(); - check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant"); - pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); - } - else if(contains(info.attributes, "pads")) - { - auto&& pad_vals = info.attributes["pads"].ints(); - pads = std::vector(pad_vals.begin(), pad_vals.end()); - } - else - { - MIGRAPHX_THROW("PARSE_PAD: pad must be available"); - } - - // check if padding is actually being done (at least one value is nonzero) - if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) - { - return info.add_instruction(make_op("identity"), args.front()); - } - if(contains(info.attributes, "mode")) { auto mode = info.attributes.at("mode").s(); @@ -152,28 +127,59 @@ struct parse_pad : op_parser { MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported"); } - return reflect_pad(info, pads, args.front()); } - if(mode != "constant") + else if(mode != "constant") { MIGRAPHX_THROW( "PARSE_PAD: migraphx currently only supports constant and reflect padding"); } + return mode; + } + else + { + // default mode + return "constant"; } + } + std::vector parse_pads(const onnx_parser::node_info& info, + const std::vector& args) const + { + std::vector pads{}; + if(args.size() >= 2) + { + auto pad_arg = args.at(1)->eval(); + check_arg_empty(pad_arg, "PARSE_PAD: `pads` input must be constant"); + pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); + } + else if(contains(info.attributes, "pads")) + { + auto&& pad_vals = info.attributes.at("pads").ints(); + pads = std::vector(pad_vals.begin(), pad_vals.end()); + } + else + { + MIGRAPHX_THROW("PARSE_PAD: `pads` must be available"); + } + return pads; + } + + float parse_constant_value(const onnx_parser& parser, + const onnx_parser::node_info& info, + const std::vector& args) const + { float value = 0.0f; - // third input is the value - if(args.size() == 3) + if(args.size() >= 3 and args.at(2)->get_shape().scalar()) { auto val_ins = args.at(2); if(not val_ins->can_eval()) { - MIGRAPHX_THROW("PARSE_PAD: input value must be constant"); + MIGRAPHX_THROW("PARSE_PAD: input `value` must be constant"); } auto val_arg = val_ins->eval(); if(val_arg.get_shape().elements() != 1) { - MIGRAPHX_THROW("PARSE_PAD: value should contain only one element"); + MIGRAPHX_THROW("PARSE_PAD: `value` should contain only one element"); } value = val_arg.at(); } @@ -181,6 +187,81 @@ struct parse_pad : op_parser { value = parser.parse_value(info.attributes.at("value")).at(); } + return value; + } + + std::vector parse_axes(const std::vector& args, + bool is_constant_mode) const + { + std::vector axes{}; + // axes is 3rd or 4th, depending on constant mode + auto pos = is_constant_mode ? 
4 : 3; + if(args.size() >= pos) + { + auto axes_arg = args.at(pos - 1)->eval(); + check_arg_empty(axes_arg, "PARSE_PAD: variable `axes` input not supported"); + axes_arg.visit([&](auto v) { axes.assign(v.begin(), v.end()); }); + } + return axes; + } + + std::vector calculate_pads_with_axes(const std::vector& pads, + const std::vector& axes, + size_t input_rank) const + { + size_t num_axes = axes.size(); + if(num_axes * 2 != pads.size()) + { + MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * " + "number of elements of axes"); + } + + std::vector new_pads(input_rank * 2); + for(size_t idx{0}; idx < num_axes; ++idx) + { + // axis can be negative + int64_t axis = axes[idx] < 0 ? input_rank + axes[idx] : axes[idx]; + // pad format is x1_begin, x2_begin, ... , x3_end, x4_end + new_pads[axis] = pads[idx]; + new_pads[axis + input_rank] = pads[idx + num_axes]; + } + return new_pads; + } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + const std::vector& args) const + { + std::vector pads = parse_pads(info, args); + + // check if padding is actually being done (at least one value is nonzero) + if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) + { + return info.add_instruction(make_op("identity"), args.front()); + } + + std::string mode = parse_mode(info, args); + bool is_constant_mode = mode == "constant"; + float value = is_constant_mode ? parse_constant_value(parser, info, args) : 0.0f; + std::vector axes = parse_axes(args, is_constant_mode); + size_t input_rank = args.front()->get_shape().ndim(); + + if(not axes.empty()) + { + pads = calculate_pads_with_axes(pads, axes, input_rank); + } + + if(pads.size() != input_rank * 2) + { + MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * " + "input rank"); + } + + if(mode == "reflect") + { + return reflect_pad(info, pads, args.front()); + } return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}), args.front()); diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 577eb99f409..de4493cc6fe 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -5107,6 +5107,32 @@ def pad_test(): return ([node], [x], [y]) +@onnx_test() +def pad_asym_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[0, 1, 0, 3, 0, 2, 0, 4], + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test() +def pad_asym_invalid_pads_error_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[0, 1, 0, 3, 0, 2], + outputs=['1']) + + return ([node], [x], [y]) + + @onnx_test() def pad_3arg_test(): values = np.array([1]) @@ -5139,6 +5165,129 @@ def pad_3arg_test(): return ([arg_val, arg_pad, node], [x], [y]) +@onnx_test() +def pad_4arg_axes_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + 
vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1, 3]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + +@onnx_test() +def pad_4arg_invalid_axes_error_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1, 2, 3]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + +@onnx_test() +def pad_4arg_neg_axes_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([-3, -1]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + @onnx_test() def pad_reflect_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) @@ -5162,6 +5311,39 @@ def pad_reflect_test(): return ([arg_pad, node], [x], [y]) +@onnx_test() +def pad_reflect_with_axes_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5]) + + sizes = np.array([2, 1]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + 
vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + node = onnx.helper.make_node('Pad', + mode='reflect', + inputs=['0', 'arg_pad', 'arg_axes'], + outputs=['1']) + + return ([arg_axes, arg_pad, node], [x], [y]) + + @onnx_test() def pad_reflect_multiaxis_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 415a612d0e0..fa006131050 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -4958,6 +4958,22 @@ TEST_CASE(pad_test) EXPECT(p == prog); } +TEST_CASE(pad_asym_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}}), l0); + auto prog = optimize_onnx("pad_asym_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_asym_invalid_pads_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("pad_asym_invalid_pads_error_test.onnx"); })); +} + TEST_CASE(pad_3arg_test) { migraphx::program p; @@ -4974,6 +4990,51 @@ TEST_CASE(pad_3arg_test) EXPECT(p == prog); } +TEST_CASE(pad_4arg_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + // axes=[1,3] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 3}}); + // constant_value=1 + mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}}); + // pads=[1,3,2,4] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}}); + auto r = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_4arg_axes_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_4arg_invalid_axes_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("pad_4arg_invalid_axes_error_test.onnx"); })); +} + +TEST_CASE(pad_4arg_neg_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + // axes=[-3,-1] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {-3, -1}}); + // constant_value=1 + mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}}); + // pads=[1,3,2,4] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}}); + auto r = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_4arg_neg_axes_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(pad_attr_dyn_test) { migraphx::program p; @@ -5032,6 +5093,27 @@ TEST_CASE(pad_reflect_test) EXPECT(p == prog); } +TEST_CASE(pad_reflect_with_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); + 
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 1}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_reflect_with_axes_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(pad_reflect_multiaxis_test) { migraphx::program p; diff --git a/test/onnx/pad_4arg_axes_test.onnx b/test/onnx/pad_4arg_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a9681f4a8dc78f3dd5da8202f59b493f8f95dd28 GIT binary patch literal 305 zcmd(Oo=y1EJ}}0tVk`6FG(#fv6{lfWg^4@6I0@F&d)0@Nz5zJlH+16OUx-v z)e_=h5@6F}Vq|vW0O<#5UBJk~r61YwdeU+1|yqI@EoJ_gWNs$Zo#0)@0zyDbv}gy-ia4;5ehsN%bf!U{zM({jdT zSL0)2j;8#z1j+INh5U1}L~%&gI{p9G%>b%ok9Qf8f@6k21VlmgEtv6-*?61n1dHGa DY&KGU literal 0 HcmV?d00001 diff --git a/test/onnx/pad_4arg_neg_axes_test.onnx b/test/onnx/pad_4arg_neg_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..69f8efd74b369512e3b30c4b732cb3374c6152e7 GIT binary patch literal 331 zcmdk{uFQ?k%qvUG$xMj{3KYku78T_e#h0WOmsmA0aw%~!8VE5ODlrEn zrf7+Bv4E5~FgP#*F*6V|Ens9PY*3Um$P_Uy5e`Nn0WKyEMj&PeViq7~O#(_9qDitr KC3&2f1cU*tDJ3-k literal 0 HcmV?d00001 diff --git a/test/onnx/pad_asym_test.onnx b/test/onnx/pad_asym_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ffc3dc17f98e4ee0eaaba0b43fee0120e60fee28 GIT binary patch literal 136 zcmdk{uFQ=uNi8n1D&$h*Vl)t9G*n^^NKDa^uWxAZ=n?A{>lD0$fZSj6lo`#4JF}ngo literal 0 HcmV?d00001 diff --git a/test/onnx/pad_reflect_with_axes_test.onnx b/test/onnx/pad_reflect_with_axes_test.onnx new file mode 100644 index 00000000000..1556ed708f9 --- /dev/null +++ b/test/onnx/pad_reflect_with_axes_test.onnx @@ -0,0 +1,18 @@ + pad_reflect_with_axes_test:ä +3arg_axes"Constant* +value**Bpad_axes  +3arg_pad"Constant* +value**Bpad_size  +2 +0 +arg_pad +arg_axes1"Pad* +mode"reflect pad_reflect_with_axes_testZ +0 +  + +b +1 +  + +B \ No newline at end of file From 5139b9308a7b8deda201e972cc24ffa6c78c9036 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Tue, 17 Oct 2023 20:07:23 -0400 Subject: [PATCH 4/6] Change driver verify to check for fp16 and --fp16 (#2334) --- src/driver/argument_parser.hpp | 7 +++++ src/driver/main.cpp | 49 ++++++++-------------------------- src/driver/verify.cpp | 36 +++++++++++++++++++++++++ src/driver/verify.hpp | 6 +++++ 4 files changed, 60 insertions(+), 38 deletions(-) diff --git a/src/driver/argument_parser.hpp b/src/driver/argument_parser.hpp index 683df8478d3..a0ee434a2b6 100644 --- a/src/driver/argument_parser.hpp +++ b/src/driver/argument_parser.hpp @@ -187,6 +187,13 @@ struct value_parser } }; +// version for std::optional object +template +struct value_parser> +{ + static T apply(const std::string& x) { return value_parser::apply(x); } +}; + struct argument_parser { struct argument diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 4978f0c7f6b..15da9b2d716 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -540,22 +540,17 @@ struct params : command struct verify : command { compiler c; - // Set to -1. 
as a nonsense initial value
-    double rms_tol = -1.0;
-    double atol    = -1.0;
-    double rtol    = -1.0;
+    std::optional<double> rms_tol;
+    std::optional<double> atol;
+    std::optional<double> rtol;
     bool per_instruction = false;
     bool reduce          = false;
 
     void parse(argument_parser& ap)
     {
         c.parse(ap);
-        ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)"));
-        ap(atol,
-           {"--atol"},
-           ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
-        ap(rtol,
-           {"--rtol"},
-           ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
+        ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
+        ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
+        ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
         ap(per_instruction,
            {"-i", "--per-instruction"},
           ap.help("Verify each instruction"),
@@ -572,33 +567,6 @@ struct verify : command<verify>
         auto t = c.ct.get_target();
         auto m = c.parameters.generate(p, t, true, c.l.batch);
 
-        // TODO remove this and make the driver able to figure out datatype most used in the model
-        // then set the tolerances appropriately. Need to check here because c.to_fp16 only set
-        // after argument_parser.parse() is run. This code is complicated because there's not a
-        // good way to change the default tolerances after reading `--fp16` but before reading
-        // `--rms-tol`, `--atol`, and `--rtol`.
-        migraphx::verify::tolerance tols{};
-        if(c.to_fp16)
-        {
-            tols = migraphx::verify::tolerance{8e-2, 4e-2, 4e-2};
-        }
-        if(not float_equal(this->rms_tol, -1.0))
-        {
-            tols.rms_tol = this->rms_tol;
-        }
-        if(not float_equal(this->atol, -1.0))
-        {
-            tols.atol = this->atol;
-        }
-        if(not float_equal(this->rtol, -1.0))
-        {
-            tols.rtol = this->rtol;
-        }
-
-        std::cout << "rms_tol: " << tols.rms_tol << std::endl;
-        std::cout << "atol: " << tols.atol << std::endl;
-        std::cout << "rtol: " << tols.rtol << std::endl;
-
         auto quantize = precision::fp32;
         if(c.to_fp16)
         {
@@ -609,6 +577,11 @@ struct verify : command<verify>
             quantize = precision::int8;
         }
 
+        auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
+        std::cout << "rms_tol: " << tols.rms_tol << std::endl;
+        std::cout << "atol: " << tols.atol << std::endl;
+        std::cout << "rtol: " << tols.rtol << std::endl;
+
         if(per_instruction)
         {
             verify_instructions(p, t, c.co, quantize, tols);
diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp
index df028a693ed..8a6c96b200b 100644
--- a/src/driver/verify.cpp
+++ b/src/driver/verify.cpp
@@ -36,6 +36,42 @@ namespace migraphx {
 namespace driver {
 inline namespace MIGRAPHX_INLINE_NS {
 
+/**
+ * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
+ * Sets fp16 tolerances if the `quantize` input is fp16 or if any fp16 instruction is found in the
+ * model.
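+ * Any tolerance supplied on the command line then overrides the corresponding default.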
+ */ +verify::tolerance get_tolerances(const program& p, + precision quantize, + std::optional rms_tol, + std::optional atol, + std::optional rtol) +{ + bool has_fp16 = any_of(p.get_modules(), [](auto&& m) { + return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); }); + }); + migraphx::verify::tolerance result{}; + if(has_fp16 or quantize == precision::fp16) + { + result.rms_tol = 8e-2; + result.atol = 4e-2; + result.rtol = 4e-2; + } + if(rms_tol) + { + result.rms_tol = *rms_tol; + } + if(atol) + { + result.atol = *atol; + } + if(rtol) + { + result.rtol = *rtol; + } + return result; +} + std::vector run_ref(program p, const parameter_map& inputs) { p.compile(migraphx::make_target("ref")); diff --git a/src/driver/verify.hpp b/src/driver/verify.hpp index 63ac161f252..582501c017a 100644 --- a/src/driver/verify.hpp +++ b/src/driver/verify.hpp @@ -32,6 +32,12 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { +verify::tolerance get_tolerances(const program& p, + precision quantize, + std::optional rms_tol, + std::optional atol, + std::optional rtol); + void verify_program(const std::string& name, const program& p, const target& t, From c39906221adae44b1a92ab0d9481bae6b066a515 Mon Sep 17 00:00:00 2001 From: Zakor Gyula <126694206+gyulaz-htec@users.noreply.github.com> Date: Wed, 18 Oct 2023 19:01:20 +0200 Subject: [PATCH 5/6] Add support for Shrink ONNX operator (#2240) --- src/onnx/parse_shrink.cpp | 85 +++++++++++++++++++++++ test/onnx/gen_onnx.py | 95 ++++++++++++++++++++++++++ test/onnx/onnx_test.cpp | 67 ++++++++++++++++++ test/onnx/shrink_hard_test.onnx | Bin 0 -> 111 bytes test/onnx/shrink_int8_test.onnx | Bin 0 -> 135 bytes test/onnx/shrink_soft_test.onnx | Bin 0 -> 127 bytes test/onnx/shrink_uint8_test.onnx | Bin 0 -> 137 bytes test/onnx/shrink_verify2_test.onnx | Bin 0 -> 133 bytes test/onnx/shrink_verify_test.onnx | Bin 0 -> 131 bytes test/onnx/verify_onnx.cpp | 106 +++++++++++++++++++++++++++++ test/py/onnx_backend_test.py | 3 - test/verify/test_shrink.cpp | 86 +++++++++++++++++++++++ 12 files changed, 439 insertions(+), 3 deletions(-) create mode 100644 src/onnx/parse_shrink.cpp create mode 100644 test/onnx/shrink_hard_test.onnx create mode 100644 test/onnx/shrink_int8_test.onnx create mode 100644 test/onnx/shrink_soft_test.onnx create mode 100644 test/onnx/shrink_uint8_test.onnx create mode 100644 test/onnx/shrink_verify2_test.onnx create mode 100644 test/onnx/shrink_verify_test.onnx create mode 100644 test/verify/test_shrink.cpp diff --git a/src/onnx/parse_shrink.cpp b/src/onnx/parse_shrink.cpp new file mode 100644 index 00000000000..669706425c7 --- /dev/null +++ b/src/onnx/parse_shrink.cpp @@ -0,0 +1,85 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_shrink : op_parser +{ + std::vector operators() const { return {{"Shrink"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + float bias = 0.0; + if(contains(info.attributes, "bias")) + { + bias = parser.parse_value(info.attributes.at("bias")).at(); + } + float lambd = 0.5; + if(contains(info.attributes, "lambd")) + { + lambd = parser.parse_value(info.attributes.at("lambd")).at(); + } + + auto x = args[0]; + auto x_shape = x->get_shape(); + auto x_type = x_shape.type(); + auto lit_bias = info.add_literal(bias); + auto lit_neg_lambd = info.add_literal(-lambd); + auto lit_lambd = info.add_literal(lambd); + + auto x_plus_bias = info.add_common_op("add", x, lit_bias); + auto x_min_bias = info.add_common_op("sub", x, lit_bias); + + auto cond1 = info.add_common_op("less", x, lit_neg_lambd); + auto cond2_a = info.add_common_op("not", cond1); + auto cond2_b = info.add_common_op("greater", x, lit_lambd); + auto cond2 = info.add_common_op("logical_and", cond2_a, cond2_b); + + auto mul1 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond1); + auto mul2 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond2); + + auto first = info.add_common_op("mul", mul1, x_plus_bias); + auto second = info.add_common_op("mul", mul2, x_min_bias); + auto ret = info.add_common_op("add", first, second); + if(ret->get_shape().type() != x_type) + { + ret = info.add_instruction(make_op("convert", {{"target_type", x_type}}), ret); + } + return ret; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index de4493cc6fe..9a1caeda801 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -7135,6 +7135,101 @@ def shape_gather_test(): return ([node_const, node_shape, node_gather], [x], [z]) +@onnx_test() +def shrink_hard_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=1.5, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_soft_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=1.5, + bias=1.5, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=-5.0, + bias=1.0, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_verify2_test(): + x = 
helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=-6.0, + bias=5.0, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_int8_test(): + x = helper.make_tensor_value_info('x', TensorProto.INT8, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.INT8, [3, 3]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=1.5, + bias=1.5, + ) + + return ([node], [x], [y]) + + +@onnx_test() +def shrink_uint8_test(): + x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.UINT8, [3, 3]) + + node = onnx.helper.make_node( + "Shrink", + inputs=["x"], + outputs=["y"], + lambd=5.0, + bias=-4.5, + ) + + return ([node], [x], [y]) + + @onnx_test() def sign_test(): x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index fa006131050..e4c2e7e63bb 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -7006,6 +7006,73 @@ TEST_CASE(shape_gather_test) EXPECT(p == prog); } +TEST_CASE(shrink_hard_test) +{ + migraphx::program p; + float bias = 0.0; + float lambd = 1.5; + std::vector lens{5}; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + + auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias}); + auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias}); + + auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd}); + auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1}); + auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd}); + auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b}); + + auto mul1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond1); + auto mul2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond2); + + auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias}); + auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias}); + add_common_op(*mm, migraphx::make_op("add"), {first, second}); + auto prog = optimize_onnx("shrink_hard_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(shrink_int8_test) +{ + migraphx::program p; + float bias = 1.5; + float lambd = 1.5; + std::vector lens{3, 3}; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + + auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias}); + auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias}); + + auto cond1 = add_common_op(*mm, 
migraphx::make_op("less"), {x, lit_neg_lambd}); + auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1}); + auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd}); + auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b}); + + auto mul1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond1); + auto mul2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond2); + + auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias}); + auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias}); + auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second}); + mm->add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), + ret); + auto prog = optimize_onnx("shrink_int8_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(sign_test) { migraphx::program p; diff --git a/test/onnx/shrink_hard_test.onnx b/test/onnx/shrink_hard_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..51c88c32cb6120249efeaa4bf9ae883c48988e57 GIT binary patch literal 111 zcmdAPfMJWf>j- literal 0 HcmV?d00001 diff --git a/test/onnx/shrink_int8_test.onnx b/test/onnx/shrink_int8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..29b85880a52531a5df0d13c1350b5fbe6f314f65 GIT binary patch literal 135 zcmdF{o5)cLe%hetf literal 0 HcmV?d00001 diff --git a/test/onnx/shrink_verify2_test.onnx b/test/onnx/shrink_verify2_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..93e89460ea6f9b478b6f4d0df7ec20187cff56c7 GIT binary patch literal 133 zcmdZhBxNQRi!v}Q za9F^|sKw94nv data{-2, -1, 0, 1, 2}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2, 0, 0, 0, 2}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(shrink_soft_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + std::vector data{-2, -1, 0, 1, 2}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-0.5, 0, 0, 0, 0.5}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(shrink_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_verify_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::half_type, {5}}; + std::vector tmp = {-10.0, -5.0, 0.0, 5.0, 10.0}; + std::vector data{tmp.cbegin(), tmp.cend()}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + tmp = {-9.0, -4.0, 1.0, 4.0, 9.0}; + std::vector gold{tmp.cbegin(), tmp.cend()}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(shrink_verify2_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_verify2_test.onnx"); + 
p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::half_type, {5}}; + std::vector tmp = {-10.0, -5.0, 0.0, 5.0, 10.0}; + std::vector data{tmp.cbegin(), tmp.cend()}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + tmp = {-5.0, 0.0, 5.0, 10.0, 5.0}; + std::vector gold{tmp.cbegin(), tmp.cend()}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(shrink_int8_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_int8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::int8_type, {3, 3}}; + std::vector data{-4, -3, -2, -1, 0, 1, 2, 3, 4}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2, -1, 0, 0, 0, 0, 0, 1, 2}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(shrink_uint8_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_uint8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::uint8_type, {3, 3}}; + std::vector data{1, 2, 3, 4, 5, 6, 7, 8, 9}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 0, 0, 0, 0, 10, 11, 12, 13}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + TEST_CASE(size_verify_test) { migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx"); diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 3bafd6f2cae..db2dec28e13 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -249,8 +249,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_reversesequence_time_cpu') backend_test.exclude(r'test_scan9_sum_cpu') backend_test.exclude(r'test_scan_sum_cpu') - backend_test.exclude(r'test_shrink_hard_cpu') - backend_test.exclude(r'test_shrink_soft_cpu') backend_test.exclude(r'test_slice_cpu') backend_test.exclude(r'test_slice_default_axes_cpu') backend_test.exclude(r'test_slice_default_steps_cpu') @@ -463,7 +461,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_sequence_model6_cpu') backend_test.exclude(r'test_sequence_model7_cpu') backend_test.exclude(r'test_sequence_model8_cpu') - backend_test.exclude(r'test_shrink_cpu') backend_test.exclude(r'test_strnorm_model_monday_casesensintive_lower_cpu') backend_test.exclude( r'test_strnorm_model_monday_casesensintive_nochangecase_cpu') diff --git a/test/verify/test_shrink.cpp b/test/verify/test_shrink.cpp new file mode 100644 index 00000000000..ed3ab4ce86e --- /dev/null +++ b/test/verify/test_shrink.cpp @@ -0,0 +1,86 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include +#include +#include + +template +struct test_shrink : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + float bias = 1.5; + float lambd = 1.5; + auto* mm = p.get_main_module(); + migraphx::shape is{T, {2, 3}}; + std::vector data; + migraphx::shape::visit(T, [&](auto as) { + as.is_signed() ? data.assign({-3.0, -2.0, -1.0, 0.0, 1.0, 2.0}) + : data.assign({3.0, 2.0, 1.0, 0.0, 1.0, 2.0}); + }); + auto x = mm->add_literal(migraphx::literal{is, data}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = + mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + + auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias}); + auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias}); + + auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd}); + auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1}); + auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd}); + auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b}); + + auto mul1 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond1); + auto mul2 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond2); + + auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias}); + auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias}); + auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second}); + if(ret->get_shape().type() != T) + { + mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), ret); + } + return p; + } +}; + +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; +template struct test_shrink; From 6e86734d12618328235748b2d54c9de1e3f203c7 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Wed, 18 Oct 2023 21:08:47 -0500 Subject: [PATCH 6/6] update script 
when offload copy is disabled (#2348) --- tools/accuracy/accuracy_checker.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tools/accuracy/accuracy_checker.py b/tools/accuracy/accuracy_checker.py index d368ca2a29e..8752bbe7f78 100644 --- a/tools/accuracy/accuracy_checker.py +++ b/tools/accuracy/accuracy_checker.py @@ -220,10 +220,16 @@ def main(): else: test_input = np.zeros(in_shape).astype(get_np_datatype(in_type)) test_inputs[name] = test_input - params[name] = migraphx.argument(test_input) + migraphx_arg = migraphx.argument(test_input) + if not args.offload_copy: + migraphx_arg = migraphx.to_gpu(migraphx_arg) + params[name] = migraphx_arg if not args.ort_run: - pred_migx = np.array(model.run(params)[-1]) + if not args.offload_copy: + pred_migx = np.array(migraphx.from_gpu(model.run(params)[-1])) + else: + pred_migx = np.array(model.run(params)[-1]) if use_onnx: sess_op = ort.SessionOptions()
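(Editor's note, not part of the patch: with offload copy disabled, the script now moves inputs to the device and copies results back explicitly. A minimal sketch of that pattern, assuming the migraphx Python bindings this script already uses; "model.onnx" and the zero-filled input are placeholders.)

import migraphx
import numpy as np

model = migraphx.parse_onnx("model.onnx")
model.compile(migraphx.get_target("gpu"), offload_copy=False)

params = {}
for name, shape in model.get_parameter_shapes().items():
    arg = migraphx.argument(np.zeros(shape.lens()).astype(np.float32))
    params[name] = migraphx.to_gpu(arg)  # explicit host-to-device copy

# explicit device-to-host copy before wrapping the result in numpy
pred = np.array(migraphx.from_gpu(model.run(params)[-1]))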