Skip to content

Commit

Permalink
Keep LayerNorm accumulator at FP32 (#2925)
Browse files Browse the repository at this point in the history
  • Loading branch information
lakhinderwalia authored May 21, 2024
1 parent 3ace932 commit 93d77e9
Showing 1 changed file with 33 additions and 21 deletions.
54 changes: 33 additions & 21 deletions src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -30,6 +30,18 @@

namespace migraphx {

template <typename T>
struct acc_type
{
using type = float;
};

template <>
struct acc_type<double>
{
using type = double;
};

template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
Expand All @@ -50,33 +62,33 @@ __device__ void generic_binary_layernorm(
using reduce_output = reduce::with_axis<Input1, Axis>;

block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
using value_type = typename Input1::type;
using vec_value_type = typename acc_type<vec_type<value_type>>::type;

auto input = r.inner([&](auto x1, auto x2) {
return migraphx::convert<vec_value_type>(op(x1, x2));
})(input1, input2);

constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);

auto means = r.reduce(op::sum{},
make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
[&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
// higher values before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto means = r.reduce(op::sum{}, make_array<vec_value_type>(0, 0), [&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
// higher values before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);

auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = implicit_conversion(eps);
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
vec_value_type eps_val = implicit_conversion(eps);
auto rsqrt_val = rsqrt(variance + eps_val);

r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;

// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
y = compute(migraphx::convert<vec_type<value_type>>((x - mean_x) * rsqrt_val), xs...);
})(output, input, inputs...);
});
}
Expand Down

0 comments on commit 93d77e9

Please sign in to comment.