Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
turneram committed May 29, 2024
1 parent f5b1daa commit 545bf19
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
25 changes: 14 additions & 11 deletions src/onnx/parse_simplified_layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,40 @@ struct parse_simplified_layer_normalization : op_parser<parse_simplified_layer_n
}
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only used for training.\n";
std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only "
"used for training.\n";
}

if(args.size() != 2)
{
MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - expected 2 got " + std::to_string(args.size()));
MIGRAPHX_THROW(
"PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - expected 2 got " +
std::to_string(args.size()));
}

auto x = args.at(0);
auto scale = args.at(1);
auto x = args.at(0);
auto scale = args.at(1);

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
int64_t x_rank = x_shape.ndim();
axis = axis < 0 ? axis + x_rank : axis;
axis = axis < 0 ? axis + x_rank : axis;

if(x_rank < 2 or x_rank > 3)
{
MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape");
}

auto x_sq = info.add_common_op("mul", x, x);
auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
auto mean = rms;
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
rms = info.add_common_op("add", rms, eps);
auto rrms = info.add_instruction(make_op("rsqrt"), rms);
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
rms = info.add_common_op("add", rms, eps);
auto rrms = info.add_instruction(make_op("rsqrt"), rms);
auto result = info.add_common_op("mul", x, rrms);
result = info.add_common_op("mul", result, scale);
result = info.add_common_op("mul", result, scale);

return {result, mean, rrms};
}
Expand Down
8 changes: 4 additions & 4 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
{
auto x_sq = add_common_op(*mm, migraphx::make_op("mul"), {x, x});
auto axis = reduce_axes[0];
axis = axis < 0 ? axis + x->get_shape().lens().size() : axis;
auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps});
axis = axis < 0 ? axis + x->get_shape().lens().size() : axis;
auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps});
auto rrms = mm->add_instruction(migraphx::make_op("rsqrt"), {rms});
auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms});
result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale});
result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale});
return p;
}

Expand Down
38 changes: 33 additions & 5 deletions test/onnx/verify/simplified_layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,22 @@
TEST_CASE(simplified_layer_normalization_test)
{
using migraphx::half;
std::vector<half> x{half{0.8}, half{-0.5}, half{0.0}, half{1.0}, half{0.5}, half{0.2}, half{0.3}, half{-0.6},
half{10.0}, half{-1.0}, half{0.0}, half{1.0}, half{1.2}, half{3.2}, half{-4.1}, half{5.3}};
std::vector<half> x{half{0.8},
half{-0.5},
half{0.0},
half{1.0},
half{0.5},
half{0.2},
half{0.3},
half{-0.6},
half{10.0},
half{-1.0},
half{0.0},
half{1.0},
half{1.2},
half{3.2},
half{-4.1},
half{5.3}};
std::vector<half> scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}};

auto p = read_onnx("simplified_layer_normalization_test.onnx");
Expand All @@ -49,8 +63,22 @@ TEST_CASE(simplified_layer_normalization_test)
std::vector<half> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<half> gold = {half{0.11633}, half{-0.1455}, half{0.0}, half{-3.2}, half{0.1162}, half{0.09296}, half{2.791}, half{3.068},
half{0.198}, half{-0.03958}, half{0.0}, half{-0.4355}, half{0.0319}, half{0.17}, half{-4.363}, half{-3.1}};
std::vector<half> gold = {half{0.11633},
half{-0.1455},
half{0.0},
half{-3.2},
half{0.1162},
half{0.09296},
half{2.791},
half{3.068},
half{0.198},
half{-0.03958},
half{0.0},
half{-0.4355},
half{0.0319},
half{0.17},
half{-4.363},
half{-3.1}};

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

0 comments on commit 545bf19

Please sign in to comment.