From 545bf19898fef6495a4f441c245df7e1d4e77b38 Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Wed, 29 May 2024 10:52:30 -0700 Subject: [PATCH] Formatting --- .../parse_simplified_layer_normalization.cpp | 25 ++++++------ test/onnx/include/onnx_test_utils.hpp | 8 ++-- .../verify/simplified_layer_normalization.cpp | 38 ++++++++++++++++--- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/src/onnx/parse_simplified_layer_normalization.cpp b/src/onnx/parse_simplified_layer_normalization.cpp index 09654cc809a..29dbcd9e9cd 100644 --- a/src/onnx/parse_simplified_layer_normalization.cpp +++ b/src/onnx/parse_simplified_layer_normalization.cpp @@ -54,37 +54,40 @@ struct parse_simplified_layer_normalization : op_parserget_shape(); auto x_dtype = x_shape.type(); int64_t x_rank = x_shape.ndim(); - axis = axis < 0 ? axis + x_rank : axis; + axis = axis < 0 ? axis + x_rank : axis; if(x_rank < 2 or x_rank > 3) { MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape"); } - + auto x_sq = info.add_common_op("mul", x, x); - auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq); + auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq); auto mean = rms; epsilon = (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon; - auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}}); - rms = info.add_common_op("add", rms, eps); - auto rrms = info.add_instruction(make_op("rsqrt"), rms); + auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}}); + rms = info.add_common_op("add", rms, eps); + auto rrms = info.add_instruction(make_op("rsqrt"), rms); auto result = info.add_common_op("mul", x, rrms); - result = info.add_common_op("mul", result, scale); + result = info.add_common_op("mul", result, scale); return {result, mean, rrms}; } diff --git a/test/onnx/include/onnx_test_utils.hpp b/test/onnx/include/onnx_test_utils.hpp index 6129f315833..a6b05456b09 100644 --- a/test/onnx/include/onnx_test_utils.hpp +++ b/test/onnx/include/onnx_test_utils.hpp @@ -179,12 +179,12 @@ make_layer_norm(const std::vector& input_shape, { auto x_sq = add_common_op(*mm, migraphx::make_op("mul"), {x, x}); auto axis = reduce_axes[0]; - axis = axis < 0 ? axis + x->get_shape().lens().size() : axis; - auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq); - rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps}); + axis = axis < 0 ? axis + x->get_shape().lens().size() : axis; + auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq); + rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps}); auto rrms = mm->add_instruction(migraphx::make_op("rsqrt"), {rms}); auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms}); - result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale}); + result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale}); return p; } diff --git a/test/onnx/verify/simplified_layer_normalization.cpp b/test/onnx/verify/simplified_layer_normalization.cpp index 0afc5c7dcf7..f3838e4a551 100644 --- a/test/onnx/verify/simplified_layer_normalization.cpp +++ b/test/onnx/verify/simplified_layer_normalization.cpp @@ -30,8 +30,22 @@ TEST_CASE(simplified_layer_normalization_test) { using migraphx::half; - std::vector x{half{0.8}, half{-0.5}, half{0.0}, half{1.0}, half{0.5}, half{0.2}, half{0.3}, half{-0.6}, - half{10.0}, half{-1.0}, half{0.0}, half{1.0}, half{1.2}, half{3.2}, half{-4.1}, half{5.3}}; + std::vector x{half{0.8}, + half{-0.5}, + half{0.0}, + half{1.0}, + half{0.5}, + half{0.2}, + half{0.3}, + half{-0.6}, + half{10.0}, + half{-1.0}, + half{0.0}, + half{1.0}, + half{1.2}, + half{3.2}, + half{-4.1}, + half{5.3}}; std::vector scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}}; auto p = read_onnx("simplified_layer_normalization_test.onnx"); @@ -49,8 +63,22 @@ TEST_CASE(simplified_layer_normalization_test) std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - std::vector gold = {half{0.11633}, half{-0.1455}, half{0.0}, half{-3.2}, half{0.1162}, half{0.09296}, half{2.791}, half{3.068}, - half{0.198}, half{-0.03958}, half{0.0}, half{-0.4355}, half{0.0319}, half{0.17}, half{-4.363}, half{-3.1}}; + std::vector gold = {half{0.11633}, + half{-0.1455}, + half{0.0}, + half{-3.2}, + half{0.1162}, + half{0.09296}, + half{2.791}, + half{3.068}, + half{0.198}, + half{-0.03958}, + half{0.0}, + half{-0.4355}, + half{0.0319}, + half{0.17}, + half{-4.363}, + half{-3.1}}; - EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); }