Commit

Add simplified_layer_normalization

turneram committed May 29, 2024
1 parent dc028dd commit f5b1daa

Showing 6 changed files with 244 additions and 1 deletion.
95 changes: 95 additions & 0 deletions src/onnx/parse_simplified_layer_normalization.cpp
@@ -0,0 +1,95 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <cmath>
#include <iostream>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// ONNX Runtime implementation for reference:
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
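//
// SimplifiedLayerNormalization is an RMSNorm-style normalization; along the chosen
// axis it computes (informal sketch, not a normative spec):
//   Y = X / sqrt(ReduceMean(X * X) + epsilon) * scale
// The parser below lowers it to mul, reduce_mean, add, rsqrt, and mul instructions.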

struct parse_simplified_layer_normalization : op_parser<parse_simplified_layer_normalization>
{
    std::vector<op_desc> operators() const { return {{"SimplifiedLayerNormalization"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }
        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "stash_type"))
        {
            std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only "
                         "used for training.\n";
        }

        if(args.size() != 2)
        {
            MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - "
                           "expected 2 got " +
                           std::to_string(args.size()));
        }

        auto x     = args.at(0);
        auto scale = args.at(1);

        auto x_shape   = x->get_shape();
        auto x_dtype   = x_shape.type();
        int64_t x_rank = x_shape.ndim();
        axis           = axis < 0 ? axis + x_rank : axis;

        if(x_rank < 2 or x_rank > 3)
        {
            MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape");
        }

        auto x_sq = info.add_common_op("mul", x, x);
        auto rms  = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
        auto mean = rms;
        epsilon =
            (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
        auto eps    = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
        rms         = info.add_common_op("add", rms, eps);
        auto rrms   = info.add_instruction(make_op("rsqrt"), rms);
        auto result = info.add_common_op("mul", x, rrms);
        result      = info.add_common_op("mul", result, scale);

        return {result, mean, rrms};
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
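
For orientation, the instruction sequence emitted above amounts to the following computation; this is a minimal NumPy sketch for illustration only (not part of the commit, and the function name is made up):

```python
# Illustration only: mirrors the parser's mul -> reduce_mean -> add(eps)
# -> rsqrt -> mul -> mul(scale) lowering.
import numpy as np

def simplified_layer_norm(x, scale, axis=-1, epsilon=1e-5):
    x_sq = x * x                                    # "mul"
    mean = np.mean(x_sq, axis=axis, keepdims=True)  # "reduce_mean" (second output)
    rms = mean + epsilon                            # "add" with the epsilon literal
    rrms = 1.0 / np.sqrt(rms)                       # "rsqrt" (third output)
    y = x * rrms * scale                            # two "mul" ops; scale broadcasts
    return y, mean, rrms
```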
18 changes: 18 additions & 0 deletions test/onnx/gen_onnx.py
@@ -9249,6 +9249,24 @@ def sign_test():
    return ([node], [x], [y])


@onnx_test()
def simplified_layer_normalization_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
    scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])

    node = onnx.helper.make_node(
        'SimplifiedLayerNormalization',
        inputs=['x', 'scale'],
        outputs=['y'],
        axis=-1,
        epsilon=1e-5,
        stash_type=1,
    )

    return ([node], [x, scale], [y])


@onnx_test()
def sin_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
16 changes: 15 additions & 1 deletion test/onnx/include/onnx_test_utils.hpp
@@ -160,7 +160,8 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
                size_t skipped_axis,
                bool skip_bias = false,
                const float eps_value = 1e-5f,
                const migraphx::shape::type_t dtype = migraphx::shape::float_type)
                const migraphx::shape::type_t dtype = migraphx::shape::float_type,
                bool simplified = false)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
@@ -174,6 +175,19 @@

    auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

    if(simplified)
    {
        auto x_sq = add_common_op(*mm, migraphx::make_op("mul"), {x, x});
        auto axis = reduce_axes[0];
        axis      = axis < 0 ? axis + x->get_shape().lens().size() : axis;
        auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
        rms         = add_common_op(*mm, migraphx::make_op("add"), {rms, eps});
        auto rrms   = mm->add_instruction(migraphx::make_op("rsqrt"), {rms});
        auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms});
        result      = add_common_op(*mm, migraphx::make_op("mul"), {result, scale});
        return p;
    }

    auto mean          = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x);
    auto x_sub_mean    = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
    auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
35 changes: 35 additions & 0 deletions test/onnx/parse/simplified_layer_normalization_test.cpp
@@ -0,0 +1,35 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(simplified_layer_normalization_test)
{
    migraphx::program p =
        make_layer_norm({2, 2, 4}, {4}, {-1}, 0, true, 1e-5f, migraphx::shape::half_type, true);

    auto prog = optimize_onnx("simplified_layer_normalization_test.onnx");
    EXPECT(p == prog);
}
25 changes: 25 additions & 0 deletions test/onnx/simplified_layer_normalization_test.onnx
@@ -0,0 +1,25 @@
(binary ONNX protobuf, not human-readable: it encodes the graph
"simplified_layer_normalization_test" with a single SimplifiedLayerNormalization
node, inputs x and scale, output y, and attributes axis, epsilon, and stash_type)
56 changes: 56 additions & 0 deletions test/onnx/verify/simplified_layer_normalization.cpp
@@ -0,0 +1,56 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>
#include <onnx_verify_utils.hpp>

TEST_CASE(simplified_layer_normalization_test)
{
    using migraphx::half;
    std::vector<half> x{half{0.8},  half{-0.5}, half{0.0},  half{1.0},
                        half{0.5},  half{0.2},  half{0.3},  half{-0.6},
                        half{10.0}, half{-1.0}, half{0.0},  half{1.0},
                        half{1.2},  half{3.2},  half{-4.1}, half{5.3}};
    std::vector<half> scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}};

    auto p = read_onnx("simplified_layer_normalization_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s_x{migraphx::shape::half_type, {2, 2, 4}};
    migraphx::shape s_s{migraphx::shape::half_type, {4}};

    migraphx::parameter_map pp;
    pp["x"]     = migraphx::argument(s_x, x.data());
    pp["scale"] = migraphx::argument(s_s, scale.data());

    auto result = p.eval(pp).back();

    std::vector<half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<half> gold = {half{0.11633}, half{-0.1455},  half{0.0},    half{-3.2},
                              half{0.1162},  half{0.09296},  half{2.791},  half{3.068},
                              half{0.198},   half{-0.03958}, half{0.0},    half{-0.4355},
                              half{0.0319},  half{0.17},     half{-4.363}, half{-3.1}};

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
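
As a sanity check, the gold vector can be approximated with a short NumPy script; this is a sketch (not part of the commit) that assumes float32 accumulation followed by a cast to float16, so small rounding differences are expected, which is why the test compares with verify_rms_range:

```python
# Illustration only (not part of the commit): recompute the expected outputs.
import numpy as np

x = np.array([0.8, -0.5, 0.0, 1.0, 0.5, 0.2, 0.3, -0.6,
              10.0, -1.0, 0.0, 1.0, 1.2, 3.2, -4.1, 5.3],
             dtype=np.float32).reshape(2, 2, 4)
scale = np.array([0.1, 0.2, 4.0, -2.2], dtype=np.float32)

rms = np.mean(x * x, axis=-1, keepdims=True) + 1e-5  # mean of squares plus epsilon
y = x / np.sqrt(rms) * scale                         # normalize, then apply scale
print(y.astype(np.float16).reshape(-1))              # close to the gold values above
```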
