From 93d77e96299d06dfbf5b847e77e8ac10c8654299 Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Mon, 20 May 2024 21:00:42 -0700 Subject: [PATCH] Keep LayerNorm accumulator at FP32 (#2925) --- .../include/migraphx/kernels/layernorm.hpp | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index b52a61eb498..c64ab553159 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,18 @@ namespace migraphx { +template +struct acc_type +{ + using type = float; +}; + +template <> +struct acc_type +{ + using type = double; +}; + template constexpr auto vec_reduce(const array& a, Op op) { @@ -50,33 +62,33 @@ __device__ void generic_binary_layernorm( using reduce_output = reduce::with_axis; block::template run([&](auto, auto r) { - auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); - using value_type = typename Input1::type; - using vec_value_type = vec_type; + using value_type = typename Input1::type; + using vec_value_type = typename acc_type>::type; + + auto input = r.inner([&](auto x1, auto x2) { + return migraphx::convert(op(x1, x2)); + })(input1, input2); + constexpr auto relements = r.template elements(); constexpr auto relements_r = vec_value_type{1.0 / relements}; auto relements_rsqrt = sqrt(relements_r); - auto means = r.reduce(op::sum{}, - make_array(vec_value_type{0}, vec_value_type{0}), - [&](auto x) { - auto x_out = x * relements_r; - // dividing x by sqrt(relements) before squaring allows computing - // higher values before overflow in low precision - auto x2_sqrt = x * relements_rsqrt; - return make_array(x_out, x2_sqrt * x2_sqrt); - })(input); + auto means = r.reduce(op::sum{}, make_array(0, 0), [&](auto x) { + auto x_out = x * relements_r; + // dividing x by sqrt(relements) before squaring allows computing + // higher values before overflow in low precision + auto x2_sqrt = x * relements_rsqrt; + return make_array(x_out, x2_sqrt * x2_sqrt); + })(input); - auto mean_x = means[0]; - auto mean_x2 = means[1]; - auto variance = mean_x2 - (mean_x * mean_x); - value_type eps_val = implicit_conversion(eps); + auto mean_x = means[0]; + auto mean_x2 = means[1]; + auto variance = mean_x2 - (mean_x * mean_x); + vec_value_type eps_val = implicit_conversion(eps); + auto rsqrt_val = rsqrt(variance + eps_val); r.inner([&](auto& y, auto x, auto... xs) { - auto m = x - mean_x; - - // m * rsqrt(mean(m ^ 2) + epsilon) - y = compute(m * rsqrt(variance + eps_val), xs...); + y = compute(migraphx::convert>((x - mean_x) * rsqrt_val), xs...); })(output, input, inputs...); }); }