diff --git a/src/onnx/parse_simplified_layer_normalization.cpp b/src/onnx/parse_simplified_layer_normalization.cpp
new file mode 100644
index 00000000000..09654cc809a
--- /dev/null
+++ b/src/onnx/parse_simplified_layer_normalization.cpp
@@ -0,0 +1,125 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+// ONNXRunTime implementation for reference:
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
+
+// Parser for the SimplifiedLayerNormalization operator.
+// The emitted instructions compute
+//
+//     Y = X * rsqrt(reduce_mean(X * X, axis) + epsilon) * scale
+//
+// Inputs:     exactly two, X (rank 2 or 3) and scale; anything else throws.
+// Attributes: axis (default -1, must lie in [-rank, rank)), epsilon (default 1e-5).
+//             stash_type only triggers a warning on stderr.
+// Returns:    {Y, reduce_mean(X * X, axis), rsqrt(reduce_mean(X * X, axis) + epsilon)}.
+struct parse_simplified_layer_normalization : op_parser<parse_simplified_layer_normalization>
+{
+    std::vector<op_desc> operators() const { return {{"SimplifiedLayerNormalization"}}; }
+
+    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
+                                       const onnx_parser& parser,
+                                       const onnx_parser::node_info& info,
+                                       std::vector<instruction_ref> args) const
+    {
+        int64_t axis = -1;
+        if(contains(info.attributes, "axis"))
+        {
+            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
+        }
+        float epsilon = 1e-5f;
+        if(contains(info.attributes, "epsilon"))
+        {
+            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
+        }
+        if(contains(info.attributes, "stash_type"))
+        {
+            std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only used for training.\n";
+        }
+
+        if(args.size() != 2)
+        {
+            MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - expected 2 got " +
+                           std::to_string(args.size()));
+        }
+
+        auto x     = args.at(0);
+        auto scale = args.at(1);
+
+        auto x_shape         = x->get_shape();
+        auto x_dtype         = x_shape.type();
+        const int64_t x_rank = x_shape.ndim();
+
+        if(x_rank < 2 or x_rank > 3)
+        {
+            MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape");
+        }
+
+        // The axis must lie in [-rank, rank). Without this check a value below -rank
+        // would still be negative after adding the rank and would be passed on to
+        // reduce_mean as-is instead of being reported as an error.
+        if(axis < -x_rank or axis >= x_rank)
+        {
+            MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: axis " +
+                           std::to_string(axis) + " is out of range for input rank " +
+                           std::to_string(x_rank));
+        }
+        if(axis < 0)
+        {
+            axis += x_rank;
+        }
+
+        // For half inputs an epsilon with magnitude below 1e-7 is replaced by 1e-7.
+        // NOTE(review): presumably this keeps the added term from vanishing in half
+        // precision; confirm the intended threshold.
+        if(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7)
+        {
+            epsilon = 1e-7f;
+        }
+
+        auto x_sq    = info.add_common_op("mul", x, x);
+        auto mean_sq = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
+        auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
+        auto shifted = info.add_common_op("add", mean_sq, eps);
+        auto rrms    = info.add_instruction(make_op("rsqrt"), shifted);
+        auto result  = info.add_common_op("mul", x, rrms);
+        result       = info.add_common_op("mul", result, scale);
+
+        // NOTE(review): the second returned value is the mean of x * x, not a mean of x;
+        // confirm that this output order is what consumers of the operator expect.
+        return {result, mean_sq, rrms};
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index e5afa0068fc..87f61f77568 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -9249,6 +9249,24 @@ def sign_test():
     return ([node], [x], [y])
 
 
+@onnx_test()
+def simplified_layer_normalization_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
+    scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])
+
+    node = onnx.helper.make_node(
+        'SimplifiedLayerNormalization',
+        inputs=['x', 'scale'],
+        outputs=['y'],
+        axis=-1,
+        epsilon=1e-5,
+        stash_type=1,
+    )
+
+    return ([node], [x, scale], [y])
+
+
 @onnx_test()
 def sin_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
diff --git a/test/onnx/include/onnx_test_utils.hpp b/test/onnx/include/onnx_test_utils.hpp
index f6ebbb150f8..6129f315833 100644
--- a/test/onnx/include/onnx_test_utils.hpp
+++ b/test/onnx/include/onnx_test_utils.hpp
@@ -160,7 +160,8 @@ make_layer_norm(const std::vector& input_shape,
                 size_t skipped_axis,
                 bool skip_bias = false,
                 const float eps_value = 1e-5f,
-                const migraphx::shape::type_t dtype = migraphx::shape::float_type)
+                const migraphx::shape::type_t dtype = migraphx::shape::float_type,
+                bool simplified = false)
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
@@ -174,6 +175,19 @@ make_layer_norm(const std::vector& input_shape,
 
     auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});
 
+    if(simplified)
+    {
+        auto x_sq = add_common_op(*mm,
migraphx::make_op("mul"), {x, x}); + auto axis = reduce_axes[0]; + axis = axis < 0 ? axis + x->get_shape().lens().size() : axis; + auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq); + rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps}); + auto rrms = mm->add_instruction(migraphx::make_op("rsqrt"), {rms}); + auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms}); + result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale}); + return p; + } + auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean}); diff --git a/test/onnx/parse/simplified_layer_normalization_test.cpp b/test/onnx/parse/simplified_layer_normalization_test.cpp new file mode 100644 index 00000000000..2ae456b3645 --- /dev/null +++ b/test/onnx/parse/simplified_layer_normalization_test.cpp @@ -0,0 +1,35 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include
+#include
+
+TEST_CASE(simplified_layer_normalization_test)
+{
+    // Expected program: make_layer_norm with half precision, no bias and simplified = true.
+    const migraphx::program expected =
+        make_layer_norm({2, 2, 4}, {4}, {-1}, 0, true, 1e-5f, migraphx::shape::half_type, true);
+    const auto parsed = optimize_onnx("simplified_layer_normalization_test.onnx");
+    EXPECT(expected == parsed);
+}
diff --git a/test/onnx/simplified_layer_normalization_test.onnx b/test/onnx/simplified_layer_normalization_test.onnx
new file mode 100644
index 00000000000..a7b73fb1462
--- /dev/null
+++ b/test/onnx/simplified_layer_normalization_test.onnx
@@ -0,0 +1,25 @@
+ #simplified_layer_normalization_test:Õ +g +x +scaley"SimplifiedLayerNormalization* +axisÿÿÿÿÿÿÿÿÿ * +epsilon¬Å'7 * + +stash_type #simplified_layer_normalization_testZ +x + + + + +Z +scale + + + +b +y + + + + +B
\ No newline at end of file
diff --git a/test/onnx/verify/simplified_layer_normalization.cpp b/test/onnx/verify/simplified_layer_normalization.cpp
new file mode 100644
index 00000000000..0afc5c7dcf7
--- /dev/null
+++ b/test/onnx/verify/simplified_layer_normalization.cpp
@@ -0,0 +1,56 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+
+#include
+#include
+#include
+#include
+
+TEST_CASE(simplified_layer_normalization_test)
+{
+    using migraphx::half;
+    // Input data for x ({2, 2, 4}) and scale ({4}), both half precision.
+    std::vector<half> x_data{half{0.8}, half{-0.5}, half{0.0}, half{1.0},
+                             half{0.5}, half{0.2}, half{0.3}, half{-0.6},
+                             half{10.0}, half{-1.0}, half{0.0}, half{1.0},
+                             half{1.2}, half{3.2}, half{-4.1}, half{5.3}};
+    std::vector<half> scale_data{half{0.1}, half{0.2}, half{4.0}, half{-2.2}};
+    const migraphx::shape x_shape{migraphx::shape::half_type, {2, 2, 4}};
+    const migraphx::shape scale_shape{migraphx::shape::half_type, {4}};
+
+    // Parse the model, compile it for the "ref" target and evaluate it on the inputs.
+    auto prog = read_onnx("simplified_layer_normalization_test.onnx");
+    prog.compile(migraphx::make_target("ref"));
+    migraphx::parameter_map params;
+    params["x"]     = migraphx::argument(x_shape, x_data.data());
+    params["scale"] = migraphx::argument(scale_shape, scale_data.data());
+    auto output     = prog.eval(params).back();
+
+    std::vector<half> actual;
+    output.visit([&](auto values) { actual.assign(values.begin(), values.end()); });
+
+    const std::vector<half> gold = {half{0.11633}, half{-0.1455}, half{0.0}, half{-3.2}, half{0.1162}, half{0.09296}, half{2.791}, half{3.068},
+                                    half{0.198}, half{-0.03958}, half{0.0}, half{-0.4355}, half{0.0319}, half{0.17}, half{-4.363}, half{-3.1}};
+    EXPECT(migraphx::verify::verify_rms_range(actual, gold));
+}