From 27a0135220935b9118add4d807f075f1738a9100 Mon Sep 17 00:00:00 2001
From: Cydral <53169060+Cydral@users.noreply.github.com>
Date: Thu, 29 Aug 2024 13:40:10 +0200
Subject: [PATCH 01/14] Add customizable dropout layer with compile-time rate
 specification (#3000)

* Add customizable dropout layer with compile-time rate specification

* Update to the name of the new dropout rate customisation class

* Fix: Replace float template parameter with int for C++17 compatibility

* Update dlib/dnn/layers_abstract.h

---------

Co-authored-by: Davis E. King
---
 dlib/dnn/layers.h          | 18 ++++++++++++++++++
 dlib/dnn/layers_abstract.h | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index f92749d7b0..77ff918a6e 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -2134,6 +2134,24 @@ namespace dlib
     template <typename SUBNET>
     using dropout = add_layer<dropout_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <int DROP_RATE_PERCENT>
+    class dropout_rate_ : public dropout_
+    {
+    public:
+        explicit dropout_rate_() : dropout_(static_cast<float>(DROP_RATE_PERCENT) / 100.0f)
+        {
+            static_assert(DROP_RATE_PERCENT >= 0 && DROP_RATE_PERCENT <= 100,
+                "DROP_RATE_PERCENT must be between 0 and 100, inclusive.");
+        }
+    };
+
+    template <int DROP_RATE_PERCENT, typename SUBNET>
+    using dropout_rate = add_layer<dropout_rate_<DROP_RATE_PERCENT>, SUBNET>;
+    template <typename SUBNET>
+    using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class multiply_
diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h
index 1ddecead03..7a29ab1347 100644
--- a/dlib/dnn/layers_abstract.h
+++ b/dlib/dnn/layers_abstract.h
@@ -1433,6 +1433,41 @@ namespace dlib
     template <typename SUBNET>
     using dropout = add_layer<dropout_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <int DROP_RATE_PERCENT>
+    class dropout_rate_ : public dropout_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+            This object represents a customizable dropout layer that inherits from
+            the dropout_ class. It allows the dropout rate to be specified at
+            compile time, which is particularly useful in deep networks with many
+            layers, where explicitly modifying the dropout rate of each layer
+            individually would be cumbersome.
+
+            The main advantage of this layer is that the dropout rate is fixed at
+            the moment of network construction, making the network architecture
+            definition clearer and more self-documenting.
+
+            TEMPLATE PARAMETERS
+                - DROP_RATE_PERCENT: An int value between 0 and 100 that specifies
+                  the dropout rate. This value is set at compile time and cannot be
+                  changed at runtime.
+        !*/
+
+    public:
+        explicit dropout_rate_();
+        /*!
+            ensures
+                - Constructs a dropout layer with a dropout rate of DROP_RATE_PERCENT/100.
+                - Calls the base class constructor dropout_(DROP_RATE_PERCENT/100.0f).
+        !*/
+    };
+
+    template <int DROP_RATE_PERCENT, typename SUBNET>
+    using dropout_rate = add_layer<dropout_rate_<DROP_RATE_PERCENT>, SUBNET>;
+    template <typename SUBNET>
+    using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>;
 
 // ----------------------------------------------------------------------------------------
 
     class multiply_
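To make the intent of this patch concrete, here is a minimal usage sketch. The
network shape is hypothetical, chosen only for illustration; the aliases are
exactly the ones the patch defines:

    #include <dlib/dnn.h>
    using namespace dlib;

    // 50% dropout fixed at compile time via dropout_rate<50, ...>, plus the
    // dropout_10 shorthand for a 10% rate.
    using net_type = loss_multiclass_log<
                     fc<10,
                     dropout_rate<50, relu<fc<84,
                     dropout_10<relu<fc<120,
                     input<matrix<unsigned char>>
                     >>>>>>>>;

Since the rate is part of the type, networks with different dropout rates are
distinct types and the rate cannot be changed after construction; the existing
runtime-configurable dropout layer remains available when that flexibility is
needed.
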
From 253098eb1bebb39ea7ca32a6db6f177c8e326b9e Mon Sep 17 00:00:00 2001
From: Adrià <1671644+arrufat@users.noreply.github.com>
Date: Sun, 1 Sep 2024 22:05:09 +0900
Subject: [PATCH 02/14] Fix layer_normalize gradients (#3001)

* Fix layer_normalize gradients

* fix layer_norm CPU

* attempt to fix the cuda version

* fix gamma_grad and beta_grad

* update cuda test

* use a block of size 1 to avoid race conditions

* improve the speed of CUDA path of layer_norm

* improve the speed of CUDA path of layer_norm
---
 dlib/cuda/cpu_dlib.cpp     | 130 ++++++++++++++------------
 dlib/cuda/cpu_dlib.h       |   4 +-
 dlib/cuda/cuda_dlib.cu     | 186 ++++++++++++++++++++++---------------
 dlib/cuda/cuda_dlib.h      |   4 +-
 dlib/cuda/tensor_tools.cpp |   8 +-
 dlib/cuda/tensor_tools.h   |  12 +--
 dlib/dnn/layers.h          |   5 +-
 dlib/test/dnn.cpp          |  13 ++-
 8 files changed, 213 insertions(+), 149 deletions(-)

diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index 6b6d6b39f0..b8b5a41239 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -1270,22 +1270,19 @@ namespace dlib
             const tensor& beta
         )
         {
-            const long num = src.k() * src.nr() * src.nc();
             DLIB_CASSERT(
                 have_same_dimensions(gamma, beta) &&
-                src.k() == gamma.k() &&
-                src.nr() == gamma.nr() &&
-                src.nc() == gamma.nc() &&
+                gamma.k() == src.k() &&
+                gamma.nr() == 1 &&
+                gamma.nc() == 1 &&
                 eps > 0,
+                "\nsrc.k():    " << src.k() <<
                 "\ngamma.k():  " << gamma.k() <<
                 "\ngamma.nr(): " << gamma.nr() <<
                 "\ngamma.nc(): " << gamma.nc() <<
                 "\nbeta.k():   " << beta.k() <<
                 "\nbeta.nr():  " << beta.nr() <<
                 "\nbeta.nc():  " << beta.nc() <<
-                "\nsrc.k():    " << src.k() <<
-                "\nsrc.nr():   " << src.nr() <<
-                "\nsrc.nc():   " << src.nc() <<
                 "\neps:  " << eps
             );
 
@@ -1296,43 +1293,50 @@ namespace dlib
             // first compute means and invstds
             means = 0;
             invstds = 0;
-            const auto p_invstds = invstds.host();
-            const auto p_means = means.host();
-            auto p_src = src.host();
+            const float* p_src = src.host();
+            float* p_invstds = invstds.host();
+            float* p_means = means.host();
+            const long num = src.nr() * src.nc();
             // compute means, and sum of squares
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    float val = p_src[n*num+i];
-                    p_means[n] += val;
-                    p_invstds[n] += val*val;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        p_means[n] += *p_src;
+                        p_invstds[n] += (*p_src) * (*p_src);
+                        ++p_src;
+                    }
                 }
             }
-            means /= num;
-            invstds /= num;
+            means /= src.k() * num;
+            invstds /= src.k() * num;
             // copy data back to host
-            invstds.host(); means.host();
+            invstds.host();
+            means.host();
             // compute variances
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                auto var = p_invstds[n] - p_means[n] * p_means[n];
-                p_invstds[n] = 1.0f / std::sqrt(var + eps);
+                p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps);
             }
 
             p_src = src.host();
-            auto p_dest = dest.host();
-            auto p_gamma = gamma.host();
-            auto p_beta = beta.host();
+            float* p_dest = dest.host();
+            const float* p_gamma = gamma.host();
+            const float* p_beta = beta.host();
             for (long n = 0; n < src.num_samples(); ++n)
             {
-                for (long i = 0; i < num; ++i)
+                for (long k = 0; k < src.k(); ++k)
                 {
-                    *p_dest = (*p_src - p_means[n])*p_invstds[n];
-                    *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
-                    ++p_src;
-                    ++p_dest;
+                    for (long i = 0; i < num; ++i)
+                    {
+                        *p_dest =
(*p_src - p_means[n]) * p_invstds[n]; + *p_dest = (*p_dest) * p_gamma[k] + p_beta[k]; + ++p_src; + ++p_dest; + } } } } @@ -1346,22 +1350,26 @@ namespace dlib const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ) { - const long num = src.k() * src.nr() * src.nc(); + const long num = src.nr() * src.nc(); DLIB_CASSERT(src.num_samples() == means.size()); DLIB_CASSERT(src.num_samples() == invstds.size()); - DLIB_CASSERT(src.k() == gamma.k()); - DLIB_CASSERT(src.nr() == gamma_grad.nr()); - DLIB_CASSERT(src.nc() == beta_grad.nc()); + DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); + DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad)); + DLIB_CASSERT(gamma.k() == src.k()); + DLIB_CASSERT(gamma.nr() == 1); + DLIB_CASSERT(gamma.nc() == 1); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); - DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad)); DLIB_CASSERT(eps > 0); beta_grad = 0; gamma_grad = 0; + auto p_grad = gradient_input.host(); auto p_src = src.host(); const auto p_gamma = gamma.host(); @@ -1370,7 +1378,6 @@ namespace dlib const auto p_invstds = invstds.host(); const auto p_means = means.host(); - resizable_tensor dvars, dmeans; dvars.copy_size(invstds); dmeans.copy_size(means); dvars = 0; @@ -1380,34 +1387,41 @@ namespace dlib for (long n = 0; n < src.num_samples(); ++n) { - for (long i = 0; i < num; ++i) + const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f); + for (long k = 0; k < src.k(); ++k) { - const float x_hat = (*p_src - p_means[n])*p_invstds[n]; - p_beta_grad[i] += *p_grad; - p_gamma_grad[i] += (*p_grad)*x_hat; + for (long i = 0; i < num; ++i) + { + const float x_hat = (*p_src - p_means[n]) * p_invstds[n]; + p_beta_grad[k] += *p_grad; + p_gamma_grad[k] += (*p_grad) * x_hat; - const float dx = *p_grad * p_gamma[n]; + const float dx = *p_grad * p_gamma[k]; - p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n]; + p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow; - ++p_grad; - ++p_src; + ++p_grad; + ++p_src; + } } } - const float invnum = 1.0f/num; p_grad = gradient_input.host(); p_src = src.host(); + const float invnum = 1.0f / (src.k() * num); for (long n = 0; n < src.num_samples(); ++n) { - for (long i = 0; i < num; ++i) + for (long k = 0; k < src.k(); ++k) { - const float dx = *p_grad * p_gamma[i]; + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; - p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum; + p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum; - ++p_grad; - ++p_src; + ++p_grad; + ++p_src; + } } } p_grad = gradient_input.host(); @@ -1415,18 +1429,20 @@ namespace dlib auto p_src_grad = src_grad.host(); for (long n = 0; n < src.num_samples(); ++n) { - for (long i = 0; i < num; ++i) + for (long k = 0; k < src.k(); ++k) { - const float dx = *p_grad * p_gamma[i]; - - *p_src_grad += dx*p_invstds[n] + - p_dvars[n] *2*(*p_src - p_means[n])*invnum + - p_dmeans[n]*invnum; + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; + *p_src_grad += dx * p_invstds[n] + + p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum + + p_dmeans[n] * invnum; - ++p_grad; - ++p_src; - ++p_src_grad; + ++p_grad; + ++p_src; + ++p_src_grad; + } } } } diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h index 1b85d75897..79ef9842b5 100644 --- a/dlib/cuda/cpu_dlib.h 
+++ b/dlib/cuda/cpu_dlib.h @@ -250,7 +250,9 @@ namespace dlib const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ); // ----------------------------------------------------------------------------------- diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 3b44cb0dad..70b8ccb0e0 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -2085,21 +2085,32 @@ namespace dlib // ---------------------------------------------------------------------------------------- - __global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num) + __global__ void _cuda_layer_normalize( + float* out, + const float* s, + float* m, + float* v, + const float* g, + const float* b, + float eps, + size_t ns, + size_t k, + size_t num + ) { // compute means and sum of squares for (auto n : grid_stride_range_y(0, ns)) { - auto p = s + n * num; + const auto ps = s + n * k * num; float means = 0; float invstds = 0; - for (auto i : grid_stride_range(0, num)) + for (auto i : grid_stride_range(0, k * num)) { - means += p[i]; - invstds += p[i] * p[i]; + means += ps[i]; + invstds += ps[i] * ps[i]; } - warp_reduce_atomic_add(m[n], means/num); - warp_reduce_atomic_add(v[n], invstds/num); + warp_reduce_atomic_add(m[n], means / (k * num)); + warp_reduce_atomic_add(v[n], invstds / (k * num)); } __syncthreads(); @@ -2108,61 +2119,19 @@ namespace dlib { for (auto i : grid_stride_range(0, 1)) { - auto var = v[n] - m[n] * m[n]; - v[n] = 1.0f / std::sqrt(var + eps); + v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps); } } __syncthreads(); for (auto n : grid_stride_range_y(0, ns)) { - for (auto i : grid_stride_range(0, num)) + const auto ps = s + n * k * num; + const auto pout = out + n * k * num; + for (auto i : grid_stride_range(0, k * num)) { - const float val = (s[n*num+i]-m[n])*v[n]; - out[n*num+i] = val*g[i]+b[i]; - } - } - } - - __global__ void _cuda_layer_normalize_gradient(float* out, float* gg, float* bg, const float* s, const float* gi, const float* m, const float* v, const float* g, float* dm, float* dv, float eps, size_t ns, size_t num) - { - for (auto n : grid_stride_range_y(0, ns)) - { - float temp_dv = 0; - for (auto i : grid_stride_range(0, num)) - { - auto idx = n*num+i; - const float x_hat = (s[idx] - m[n])*v[n]; - bg[i] += gi[idx]; - gg[i] += gi[idx]*x_hat; - - const float dx = gi[idx] * g[n]; - temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n]; - } - warp_reduce_atomic_add(dv[n], temp_dv); - } - __syncthreads(); - - for (auto n : grid_stride_range_y(0, ns)) - { - float temp_dm = 0; - for (auto i : grid_stride_range(0, num)) - { - auto idx = n*num+i; - const float dx = gi[idx]*g[i]; - temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num; - } - warp_reduce_atomic_add(dm[n], temp_dm); - } - __syncthreads(); - - for (auto n : grid_stride_range_y(0, ns)) - { - for (auto i : grid_stride_range(0, num)) - { - auto idx = n*num+i; - const float dx = gi[idx]*g[i]; - out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num; + pout[i] = (ps[i] - m[n]) * v[n]; + pout[i] = pout[i] * g[i / num] + b[i / num]; } } } @@ -2177,22 +2146,20 @@ namespace dlib const tensor& beta ) { - const long num = src.k() * src.nr() * src.nc(); + const long num = src.nr() * src.nc(); DLIB_CASSERT( have_same_dimensions(gamma, beta) && - src.k() == gamma.k() && - src.nr() == gamma.nr() && - src.nc() == gamma.nc() && + gamma.k() == 
src.k() && + gamma.nr() == 1 && + gamma.nc() == 1 && eps > 0, + "\nsrc.k(): " << src.k() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << - "\nsrc.k(): " << src.k() << - "\nsrc.nr(): " << src.nr() << - "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); @@ -2201,8 +2168,78 @@ namespace dlib invstds.set_size(src.num_samples()); means = 0; invstds = 0; - launch_kernel(_cuda_layer_normalize, max_jobs(num, src.num_samples()), dest.device(), src.device(), - means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), num); + launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(), + means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_layer_normalize_gradient( + float* out, + float* gg, + float* bg, + const float* s, + const float* gi, + const float* m, + const float* v, + const float* g, + float* dm, + float* dv, + float eps, + size_t ns, + size_t ks, + size_t num) + { + for (auto nk : grid_stride_range_y(0, ns * ks)) + { + const auto n = nk / ks; + const auto k = nk % ks; + const auto ps = s + (n * ks + k) * num; + const auto pgi = gi + (n * ks + k) * num; + const float invstd_pow = -0.5 * std::pow(v[n], 3.0f); + float temp_bg = 0; + float temp_gg = 0; + float temp_dv = 0; + for (auto i : grid_stride_range(0, num)) + { + const float x_hat = (ps[i] - m[n]) * v[n]; + const float dx = pgi[i] * g[i / num]; + temp_bg += pgi[i]; + temp_gg += pgi[i] * x_hat; + temp_dv += dx * (ps[i] - m[n]) * invstd_pow; + } + warp_reduce_atomic_add(bg[k], temp_bg); + warp_reduce_atomic_add(gg[k], temp_gg); + warp_reduce_atomic_add(dv[n], temp_dv); + } + __syncthreads(); + + const float invnum = 1.0f / (ks * num); + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = s + n * ks * num; + const auto pgi = gi + n * ks * num; + float temp_dm = 0; + for (auto i : grid_stride_range(0, ks * num)) + { + const float dx = pgi[i] * g[i / num]; + temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum; + } + warp_reduce_atomic_add(dm[n], temp_dm); + } + __syncthreads(); + + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = s + n * ks * num; + const auto pgi = gi + n * ks * num; + const auto pout = out + n * ks * num; + for (auto i : grid_stride_range(0, ks * num)) + { + const float dx = pgi[i] * g[i / num]; + pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum; + } + } } void layer_normalize_gradient ( @@ -2214,32 +2251,33 @@ namespace dlib const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ) { - const long num = src.k() * src.nr() * src.nc(); + const long num = src.nr() * src.nc(); DLIB_CASSERT(src.num_samples() == means.size()); DLIB_CASSERT(src.num_samples() == invstds.size()); - DLIB_CASSERT(src.k() == gamma.k()); - DLIB_CASSERT(src.nr() == gamma.nr()); - DLIB_CASSERT(src.nc() == gamma.nc()); + DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); + DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad)); + DLIB_CASSERT(gamma.k() == src.k()); + DLIB_CASSERT(gamma.nr() == 1); + DLIB_CASSERT(gamma.nc() == 1); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); 
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); - DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma)); - DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad)); DLIB_CASSERT(eps > 0); beta_grad = 0; gamma_grad = 0; - resizable_tensor dvars, dmeans; dvars.copy_size(invstds); dmeans.copy_size(means); dvars = 0; dmeans = 0; - launch_kernel(_cuda_layer_normalize_gradient, max_jobs(num, src.num_samples()), + launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()), src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(), gradient_input.device(), means.device(), invstds.device(), gamma.device(), - dmeans.device(), dvars.device(), eps, src.num_samples(), num); + dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num); } // ---------------------------------------------------------------------------------------- diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h index a9a3517723..d157e1b655 100644 --- a/dlib/cuda/cuda_dlib.h +++ b/dlib/cuda/cuda_dlib.h @@ -357,7 +357,9 @@ namespace dlib const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ); // ----------------------------------------------------------------------------------- diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp index af9eea9c1a..7e11000ceb 100644 --- a/dlib/cuda/tensor_tools.cpp +++ b/dlib/cuda/tensor_tools.cpp @@ -684,13 +684,15 @@ namespace dlib { namespace tt const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ) { #ifdef DLIB_USE_CUDA - cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); + cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars); #else - cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); + cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars); #endif } diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h index 15602a826f..31310a9612 100644 --- a/dlib/cuda/tensor_tools.h +++ b/dlib/cuda/tensor_tools.h @@ -814,13 +814,13 @@ namespace dlib { namespace tt /*! requires - eps > 0 - - src.num_samples() == gamma.size() == beta.size() + - src.k() == gamma.size() == beta.size() + - gamma.num_samples() == gamma.nr() == gamma.nc() == 1 - have_same_dimensions(gamma, beta) == true - - beta.num_samples() ==beta.nr() ==gamma.nc() == 1 ensures - have_same_dimensions(#dest, src) == true - #means.size() == invstds.size() == src.num_samples() - - #dest == the normalized version of src. + - #dest == the normalized version of src, sample-wise. - #means == the mean values of the contents of src. - #invstds == 1/(the standard deviation values of the contents of src). !*/ @@ -834,7 +834,9 @@ namespace dlib { namespace tt const tensor& gamma, tensor& src_grad, tensor& gamma_grad, - tensor& beta_grad + tensor& beta_grad, + resizable_tensor& dmeans, + resizable_tensor& dvars ); /*! 
        requires
@@ -847,8 +849,6 @@ namespace dlib { namespace tt
             - have_same_dimensions(gamma, beta_grad) == true
             - means.size() == src.num_samples()
             - invstds.size() == src.num_samples()
-            - have_same_dimensions(means, gamma) == true
-            - have_same_dimensions(invstds, gamma) == true
         ensures
             - Let f(src,gamma,beta) == dot(gradient_input, dest output of
               layer_normalize(eps,dest,means,invstds,src,gamma,beta))
diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index 77ff918a6e..7dd6b51e43 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -1403,7 +1403,7 @@ namespace dlib
         template <typename SUBNET>
         void setup (const SUBNET& sub)
         {
-            gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
+            gamma = alias_tensor(1, sub.get_output().k());
             beta = gamma;
 
             params.set_size(gamma.size()+beta.size());
@@ -1426,7 +1426,7 @@ namespace dlib
             auto g = gamma(params, 0);
             auto g_grad = gamma(params_grad, 0);
             auto b_grad = beta(params_grad, gamma.size());
-            tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
+            tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad, dmeans, dvars);
         }
 
         const tensor& get_layer_params() const { return params; };
@@ -1493,6 +1493,7 @@ namespace dlib
         resizable_tensor params;
         alias_tensor gamma, beta;
         resizable_tensor means, invstds;
+        resizable_tensor dmeans, dvars;
         double learning_rate_multiplier;
         double weight_decay_multiplier;
         double bias_learning_rate_multiplier;
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 7a40c6ccd0..c4bff74df3 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -607,7 +607,7 @@ namespace
         tt::tensor_rand rnd(0);
         rnd.fill_uniform(x);
         resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
-        resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
+        resizable_tensor gamma(1, x.k(), 1, 1), beta(1, x.k(), 1, 1);
         gamma = 1;
         beta = 0;
         const float eps = 1e-5;
@@ -639,16 +639,19 @@ namespace
         DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
         DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
         resizable_tensor gradient_input(x);
-        resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
-        resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
+        resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), 1, 1), beta_grad_cpu(1, x.k(), 1, 1);
+        resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), 1, 1), beta_grad_cuda(1, x.k(), 1, 1);
+        resizable_tensor dmeans_cpu, dvars_cpu, dmeans_cuda, dvars_cuda;
         rnd.fill_gaussian(gradient_input);
         src_grad_cpu = 0;
         src_grad_cuda = 0;
-        cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu);
-        cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda);
+        cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu, dmeans_cpu, dvars_cpu);
+        cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda, dmeans_cuda, dvars_cuda);
 
         DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
         DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
         DLIB_TEST(max(abs(mat(beta_grad_cpu) - mat(beta_grad_cuda))) < 1e-5);
+        DLIB_TEST(max(abs(mat(dmeans_cpu) - mat(dmeans_cuda))) < 1e-4);
+        DLIB_TEST(max(abs(mat(dvars_cpu) - mat(dvars_cuda))) < 1e-4);
 #endif
     }
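The practical upshot of this patch: layer_norm now normalizes each sample over
all of its k*nr*nc values, and gamma/beta hold one value per channel (shape
[1, k, 1, 1]) instead of one value per element. A minimal sketch of calling the
updated tt:: routine directly, with hypothetical tensor sizes chosen only for
illustration:

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        resizable_tensor x(2, 3, 4, 5);          // [num_samples, k, nr, nc]
        tt::tensor_rand rnd;
        rnd.fill_gaussian(x);

        resizable_tensor gamma(1, x.k(), 1, 1);  // one scale per channel...
        resizable_tensor beta(1, x.k(), 1, 1);   // ...and one shift per channel
        gamma = 1;
        beta = 0;

        resizable_tensor out, means, invstds;
        tt::layer_normalize(1e-5, out, means, invstds, x, gamma, beta);
        // means and invstds now hold one statistic per sample (2 entries here),
        // each computed over all k*nr*nc values of that sample.
    }
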
From fafdac37f12814981fdc97013aaef4886b6d1bc0 Mon Sep 17 00:00:00 2001
From: Cydral <53169060+Cydral@users.noreply.github.com>
Date: Sat, 7 Sep 2024 19:29:40 +0200
Subject: [PATCH 03/14] Add RMS Normalization Layer (#2999)

* Add RMS Normalization Layer

* Update dnn.cpp

* Missing entry in visitors.h to take into account the new rms_norm_ layer

* Fix test function name

* Fix dangling pointer issue in CUDA implementation of rms_normalize_gradient

* Fixing the dnn.cpp test program for the new rms_norm_ layer

* General update of the rms_norm_ class
---
 dlib/cuda/cpu_dlib.cpp       | 138 ++++++++++++++++++++++++++++
 dlib/cuda/cpu_dlib.h         |  20 ++++
 dlib/cuda/cuda_dlib.cu       | 160 ++++++++++++++++++++++++++++++++
 dlib/cuda/cuda_dlib.h        |  20 ++++
 dlib/cuda/tensor_tools.cpp   |  34 +++++++
 dlib/cuda/tensor_tools.h     |  53 ++++++++++-
 dlib/dnn/layers.h            | 125 +++++++++++++++++++++++++
 dlib/dnn/layers_abstract.h   | 172 +++++++++++++++++++++++++++++++++++
 dlib/dnn/visitors.h          |  39 ++++++++
 dlib/dnn/visitors_abstract.h |   4 +-
 dlib/test/dnn.cpp            | 102 ++++++++++++++++++++-
 11 files changed, 863 insertions(+), 4 deletions(-)

diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index b8b5a41239..1acb35e004 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -1447,6 +1447,144 @@ namespace dlib
         }
     }
 
+// -----------------------------------------------------------------------------------
+
+    void rms_normalize(
+        const double eps,
+        resizable_tensor& dest,
+        resizable_tensor& scale,
+        const tensor& src,
+        const tensor& gamma
+    )
+    {
+        DLIB_CASSERT(
+            gamma.k() == src.k() &&
+            gamma.nr() == 1 &&
+            gamma.nc() == 1 &&
+            eps > 0,
+            "\nsrc.k():    " << src.k() <<
+            "\ngamma.k():  " << gamma.k() <<
+            "\ngamma.nr(): " << gamma.nr() <<
+            "\ngamma.nc(): " << gamma.nc() <<
+            "\neps:  " << eps
+        );
+
+        const long ns = src.num_samples();
+        const long ks = src.k();
+        const long num = src.nr() * src.nc();
+
+        dest.copy_size(src);
+        scale.set_size(ns);
+
+        // Compute RMS values
+        scale = 0;
+        const float* p_src = src.host();
+        float* p_scale = scale.host();
+        for (long n = 0; n < ns; ++n)
+        {
+            for (long k = 0; k < ks; ++k)
+            {
+                for (long i = 0; i < num; ++i)
+                {
+                    p_scale[n] += (*p_src) * (*p_src);
+                    ++p_src;
+                }
+            }
+            p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
+        }
+        scale.host();
+
+        // Apply RMS normalization
+        p_src = src.host();
+        float* p_dest = dest.host();
+        const float* p_gamma = gamma.host();
+        for (long n = 0; n < ns; ++n)
+        {
+            for (long k = 0; k < ks; ++k)
+            {
+                for (long i = 0; i < num; ++i)
+                {
+                    *p_dest = (*p_src) * p_scale[n] * p_gamma[k];
+                    ++p_src;
+                    ++p_dest;
+                }
+            }
+        }
+    }
+
+    void rms_normalize_gradient(
+        const tensor& gradient_input,
+        const tensor& scale,
+        const tensor& src,
+        const tensor& gamma,
+        tensor& src_grad,
+        tensor& gamma_grad,
+        resizable_tensor& dscale
+    )
+    {
+        DLIB_CASSERT(src.num_samples() == scale.size());
+        DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
+        DLIB_CASSERT(gamma.k() == src.k());
+        DLIB_CASSERT(gamma.nr() == 1);
+        DLIB_CASSERT(gamma.nc() == 1);
+        DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+        DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+
+        const long ns = src.num_samples();
+        const long ks = src.k();
+        const long num = src.nr() * src.nc();
+
+        gamma_grad = 0;
+        dscale.copy_size(scale);
+        dscale =
0; + + auto p_grad = gradient_input.host(); + auto p_src = src.host(); + const auto p_gamma = gamma.host(); + const auto p_gamma_grad = gamma_grad.host(); + const auto p_scale = scale.host(); + auto p_dscale = dscale.host(); + + for (long n = 0; n < ns; ++n) + { + const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f); + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + const float x_hat = *p_src * p_scale[n]; + p_gamma_grad[k] += (*p_grad) * x_hat; + + const float dx = *p_grad * p_gamma[k]; + p_dscale[n] += dx * *p_src * scale_pow; + + ++p_grad; + ++p_src; + } + } + } + + p_grad = gradient_input.host(); + p_src = src.host(); + auto p_src_grad = src_grad.host(); + const float invnum = 1.0f / (ks * num); + for (long n = 0; n < ns; ++n) + { + for (long k = 0; k < ks; ++k) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; + *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum; + + ++p_grad; + ++p_src; + ++p_src_grad; + } + } + } + } + // ----------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h index 79ef9842b5..45bc57fa97 100644 --- a/dlib/cuda/cpu_dlib.h +++ b/dlib/cuda/cpu_dlib.h @@ -255,6 +255,26 @@ namespace dlib resizable_tensor& dvars ); + // ----------------------------------------------------------------------------------- + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ); + + void rms_normalize_gradient( + const tensor& gradient_input, + const tensor& scale, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + resizable_tensor& dscale + ); + // ----------------------------------------------------------------------------------- void threshold ( diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 70b8ccb0e0..3484baa7b3 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -2280,6 +2280,166 @@ namespace dlib dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num); } + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_rms_normalize( + float* dest, + float* scale, + const float* src, + const float* gamma, + float eps, + size_t ns, + size_t ks, + size_t num + ) + { + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = src + n * ks * num; + float sum_squares = 0.0f; + for (auto i : grid_stride_range(0, ks * num)) + { + sum_squares += ps[i] * ps[i]; + } + warp_reduce_atomic_add(scale[n], sum_squares / (ks * num)); + } + __syncthreads(); + + for (auto n : grid_stride_range_y(0, ns)) + { + for (auto i : grid_stride_range(0, 1)) + { + scale[n] = 1.0f / std::sqrt(scale[n] + eps); + } + } + __syncthreads(); + + for (auto n : grid_stride_range_y(0, ns)) + { + const auto ps = src + n * ks * num; + const auto pd = dest + n * ks * num; + for (auto i : grid_stride_range(0, ks * num)) + { + pd[i] = ps[i] * scale[n] * gamma[i / num]; + } + } + } + + void rms_normalize( + const double eps, + resizable_tensor& dest, + resizable_tensor& scale, + const tensor& src, + const tensor& gamma + ) + { + DLIB_CASSERT( + gamma.k() == src.k() && + gamma.nr() == 1 && + gamma.nc() == 1 && + eps > 0, + "\nsrc.k(): " << src.k() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\neps: " << eps + ); + + const long ns = src.num_samples(); + 
const long ks = src.k();
+        const long num = src.nr() * src.nc();
+
+        dest.copy_size(src);
+        scale.set_size(ns);
+        scale = 0;
+
+        launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
+            dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
+    }
+
+    // ----------------------------------------------------------------------------------------
+
+    __global__ void _cuda_rms_normalize_gradient(
+        float* src_grad,
+        float* gamma_grad,
+        float* dscale,
+        const float* src,
+        const float* gradient_input,
+        const float* scale,
+        const float* gamma,
+        size_t ns,
+        size_t ks,
+        size_t num
+    )
+    {
+        for (auto nk : grid_stride_range_y(0, ns * ks))
+        {
+            const auto n = nk / ks;
+            const auto k = nk % ks;
+            const auto ps = src + (n * ks + k) * num;
+            const auto pgi = gradient_input + (n * ks + k) * num;
+            const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
+            float temp_gg = 0.0f;
+            float temp_ds = 0.0f;
+            for (auto i : grid_stride_range(0, num))
+            {
+                const float x_hat = ps[i] * scale[n];
+                const float dx = pgi[i] * gamma[i / num];
+                temp_gg += pgi[i] * x_hat;
+                temp_ds += dx * ps[i] * scale_pow;
+            }
+            warp_reduce_atomic_add(gamma_grad[k], temp_gg);
+            warp_reduce_atomic_add(dscale[n], temp_ds);
+        }
+        __syncthreads();
+
+        const float invnum = 1.0f / (ks * num);
+        for (auto n : grid_stride_range_y(0, ns))
+        {
+            const auto ps = src + n * ks * num;
+            const auto pgi = gradient_input + n * ks * num;
+            const auto psg = src_grad + n * ks * num;
+            for (auto i : grid_stride_range(0, ks * num))
+            {
+                const float dx = pgi[i] * gamma[i / num];
+                psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
+            }
+        }
+    }
+
+    void rms_normalize_gradient(
+        const tensor& gradient_input,
+        const tensor& scale,
+        const tensor& src,
+        const tensor& gamma,
+        tensor& src_grad,
+        tensor& gamma_grad,
+        resizable_tensor& dscale
+    )
+    {
+        DLIB_CASSERT(src.num_samples() == scale.size());
+        DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
+        DLIB_CASSERT(gamma.k() == src.k());
+        DLIB_CASSERT(gamma.nr() == 1);
+        DLIB_CASSERT(gamma.nc() == 1);
+        DLIB_CASSERT(have_same_dimensions(gradient_input, src));
+        DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
+
+        const long ns = src.num_samples();
+        const long ks = src.k();
+        const long num = src.nr() * src.nc();
+
+        gamma_grad = 0;
+        dscale.copy_size(scale);
+        dscale = 0;
+
+        // Launch the CUDA kernel
+        launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
+            src_grad.device(), gamma_grad.device(), dscale.device(),
+            src.device(), gradient_input.device(), scale.device(), gamma.device(),
+            ns, ks, num);
+    }
+
 // ----------------------------------------------------------------------------------------
 
     __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h
index d157e1b655..059c6dd442 100644
--- a/dlib/cuda/cuda_dlib.h
+++ b/dlib/cuda/cuda_dlib.h
@@ -362,6 +362,26 @@ namespace dlib
         resizable_tensor& dvars
     );
 
+    // -----------------------------------------------------------------------------------
+
+    void rms_normalize(
+        const double eps,
+        resizable_tensor& dest,
+        resizable_tensor& scale,
+        const tensor& src,
+        const tensor& gamma
+    );
+
+    void rms_normalize_gradient(
+        const tensor& gradient_input,
+        const tensor& scale,
+        const tensor& src,
+        const tensor& gamma,
+        tensor& src_grad,
+        tensor& gamma_grad,
+        resizable_tensor& dscale
+    );
+
     // -----------------------------------------------------------------------------------
 
     void threshold (
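Both CUDA kernels above follow the same pattern: one grid y-index per sample
(or per sample/channel pair), a grid-stride loop over that slice to build a
per-thread partial sum, and a reduction into a single per-sample accumulator,
with __syncthreads() separating the phases. grid_stride_range, launch_kernel,
and warp_reduce_atomic_add are dlib-internal helpers; for readers who want the
pattern outside those wrappers, here is a minimal standalone CUDA sketch of the
first phase (the per-sample mean of squares), with our own names and a plain
atomicAdd standing in for dlib's warp-level reduction:

    #include <cuda_runtime.h>

    // out must be zeroed before launch, just as `scale = 0;` precedes the
    // launch_kernel call in rms_normalize above.
    __global__ void mean_squares_per_sample(
        float* out,        // one accumulator per sample, length ns
        const float* src,  // ns * elems contiguous floats
        size_t ns,
        size_t elems       // k * nr * nc elements per sample
    )
    {
        for (size_t n = blockIdx.y; n < ns; n += gridDim.y)
        {
            const float* ps = src + n * elems;
            float partial = 0.0f;
            // grid-stride loop: the x-dimension of the grid sweeps the sample
            for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < elems;
                 i += (size_t)gridDim.x * blockDim.x)
            {
                partial += ps[i] * ps[i];
            }
            // one atomic per thread instead of a warp-level reduction
            atomicAdd(&out[n], partial / elems);
        }
    }

    // Example launch: dim3 grid(128, ns); mean_squares_per_sample<<<grid, 256>>>(...);
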
diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp
index 7e11000ceb..f4b684dec9 100644
--- a/dlib/cuda/tensor_tools.cpp
+++ b/dlib/cuda/tensor_tools.cpp
@@ -696,6 +696,40 @@ namespace dlib { namespace tt
 #endif
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void rms_normalize(
+        const double eps,
+        resizable_tensor& dest,
+        resizable_tensor& scale,
+        const tensor& src,
+        const tensor& gamma
+    )
+    {
+#ifdef DLIB_USE_CUDA
+        cuda::rms_normalize(eps, dest, scale, src, gamma);
+#else
+        cpu::rms_normalize(eps, dest, scale, src, gamma);
+#endif
+    }
+
+    void rms_normalize_gradient(
+        const tensor& gradient_input,
+        const tensor& scale,
+        const tensor& src,
+        const tensor& gamma,
+        tensor& src_grad,
+        tensor& gamma_grad,
+        resizable_tensor& dscale
+    )
+    {
+#ifdef DLIB_USE_CUDA
+        cuda::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
+#else
+        cpu::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale);
+#endif
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void threshold (
diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h
index 31310a9612..245035c56a 100644
--- a/dlib/cuda/tensor_tools.h
+++ b/dlib/cuda/tensor_tools.h
@@ -857,7 +857,58 @@ namespace dlib { namespace tt
             - Assigns the gradient of f() with respect to beta to #beta_grad.
     !*/
 
-    // -----------------------------------------------------------------------------------
+// -----------------------------------------------------------------------------------
+
+    void rms_normalize(
+        const double eps,
+        resizable_tensor& dest,
+        resizable_tensor& scale,
+        const tensor& src,
+        const tensor& gamma
+    );
+    /*!
+        requires
+            - eps > 0
+            - gamma.k() == src.k()
+            - gamma.nr() == 1
+            - gamma.nc() == 1
+        ensures
+            - have_same_dimensions(#dest, src) == true
+            - #scale.size() == src.num_samples()
+            - #dest == the RMS normalized version of src
+            - #scale contains the reciprocals of the RMS (root mean square) values
+              of each sample in src, i.e. the factor each sample is multiplied by.
+            - Each element of #dest is computed as:
+                - #dest[n, k, i, j] == src[n, k, i, j] * gamma[k] * scale[n]
+              where n is the sample index, k is the channel index, and i, j are
+              the spatial indices.
+    !*/
+
+    void rms_normalize_gradient(
+        const tensor& gradient_input,
+        const tensor& scale,
+        const tensor& src,
+        const tensor& gamma,
+        tensor& src_grad,
+        tensor& gamma_grad,
+        resizable_tensor& dscale
+    );
+    /*!
+        requires
+            - scale.size() == src.num_samples()
+            - have_same_dimensions(gamma, gamma_grad)
+            - gamma.k() == src.k()
+            - gamma.nr() == 1
+            - gamma.nc() == 1
+            - have_same_dimensions(gradient_input, src)
+            - have_same_dimensions(gradient_input, src_grad)
+        ensures
+            - Let f(src, gamma) == dot(gradient_input, dest output of
+              rms_normalize(eps, dest, scale, src, gamma))
+            - Adds the gradient of f() with respect to src to #src_grad
+            - Assigns the gradient of f() with respect to gamma to #gamma_grad
+            - #dscale contains the gradients of f() with respect to the scale
+              (reciprocal RMS) values.
+ !*/ + +// ----------------------------------------------------------------------------------- void threshold ( tensor& data, diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index 7dd6b51e43..ef11c1b34b 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -1504,6 +1504,131 @@ namespace dlib template using layer_norm = add_layer; +// ---------------------------------------------------------------------------------------- + + const double DEFAULT_RMS_NORM_EPS = 1e-5; + + class rms_norm_ + { + public: + explicit rms_norm_( + double eps_ = DEFAULT_RMS_NORM_EPS + ) : + learning_rate_multiplier(1), + weight_decay_multiplier(0), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(1), + eps(eps_) + { + } + + double get_eps() const { return eps; } + + double get_learning_rate_multiplier() const { return learning_rate_multiplier; } + double get_weight_decay_multiplier() const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier() const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier() const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + inline dpoint map_input_to_output(const dpoint& p) const { return p; } + inline dpoint map_output_to_input(const dpoint& p) const { return p; } + + template + void setup(const SUBNET& sub) + { + gamma = alias_tensor(1, sub.get_output().k()); + params.set_size(gamma.size()); + gamma(params, 0) = 1; + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto g = gamma(params, 0); + tt::rms_normalize(eps, output, scale, sub.get_output(), g); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + auto g = gamma(params, 0); + auto g_grad = gamma(params_grad, 0); + tt::rms_normalize_gradient(gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale); + } + + const tensor& get_layer_params() const { return params; }; + tensor& get_layer_params() { return params; }; + + friend void serialize(const rms_norm_& item, std::ostream& out) + { + serialize("rms_norm_", out); + serialize(item.params, out); + serialize(item.gamma, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + serialize(item.eps, out); + } + + friend void deserialize(rms_norm_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "rms_norm_") + throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::rms_norm_."); + deserialize(item.params, in); + deserialize(item.gamma, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + deserialize(item.eps, in); + } + + friend std::ostream& operator<<(std::ostream& out, const rms_norm_& item) + { + out << "rms_norm"; + out << " (eps=" << item.eps << ")"; + out << " learning_rate_mult=" << item.learning_rate_multiplier; + out << " weight_decay_mult=" 
<< item.weight_decay_multiplier;
+            out << " bias_learning_rate_mult=" << item.bias_learning_rate_multiplier;
+            out << " bias_weight_decay_mult=" << item.bias_weight_decay_multiplier;
+            return out;
+        }
+
+        friend void to_xml(const rms_norm_& item, std::ostream& out)
+        {
+            out << "<rms_norm eps='" << item.eps << "'>\n";
+            out << mat(item.params);
+            out << "</rms_norm>\n";
+        }
+
+    private:
+        resizable_tensor params;
+        alias_tensor gamma;
+        resizable_tensor scale;
+        resizable_tensor dscale;
+        double learning_rate_multiplier;
+        double weight_decay_multiplier;
+        double bias_learning_rate_multiplier;
+        double bias_weight_decay_multiplier;
+        double eps;
+    };
+
+    template <typename SUBNET>
+    using rms_norm = add_layer<rms_norm_, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     enum layer_mode
     {
diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h
index 7a29ab1347..8c09442e75 100644
--- a/dlib/dnn/layers_abstract.h
+++ b/dlib/dnn/layers_abstract.h
@@ -1468,6 +1468,7 @@ namespace dlib
     using dropout_rate = add_layer<dropout_rate_<DROP_RATE_PERCENT>, SUBNET>;
     template <typename SUBNET>
     using dropout_10 = add_layer<dropout_rate_<10>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class multiply_
@@ -1665,6 +1666,177 @@ namespace dlib
         !*/
     };
 
+// ----------------------------------------------------------------------------------------
+
+    const double DEFAULT_RMS_NORM_EPS = 1e-5;
+
+    class rms_norm_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+            This object implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+            defined above, specifically defining a root mean square (RMS)
+            normalization layer.
+
+            RMS normalization is a technique that normalizes the input tensor based
+            on the root mean square (RMS) of its elements. Unlike traditional layer
+            normalization, which both centers and scales the data, RMS normalization
+            only scales by the RMS value. This makes it computationally more
+            efficient, as it avoids computing the mean and subtracting it from each
+            element.
+
+            This layer produces output tensors with the same dimensionality as the
+            input tensors. Specifically, for an input tensor with shape
+            [num_samples, k, nr, nc], the RMS statistic is computed over all
+            [k, nr, nc] elements, independently for each sample in the
+            [num_samples] dimension. The learnable scaling parameter (gamma) has
+            one value per channel, i.e. size [k].
+
+            The key characteristics of this layer are:
+                - The RMS of the elements in each sample is standardized to 1.
+                - It does not center the data (i.e., it does not subtract the mean).
+                - A learnable scaling factor (gamma) is applied after normalization,
+                  allowing the model to adapt the scaling dynamically.
+
+            This layer is particularly effective in various natural language
+            processing tasks, where it has been shown to provide performance
+            similar to or better than traditional layer normalization, with
+            reduced computational overhead.
+        !*/
+
+    public:
+        explicit rms_norm_(
+            double eps_ = DEFAULT_RMS_NORM_EPS
+        );
+        /*!
+            requires
+                - eps_ > 0
+            ensures
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 0
+                - #get_bias_learning_rate_multiplier() == 1
+                - #get_bias_weight_decay_multiplier() == 1
+                - #get_eps() == eps_
+        !*/
+
+        double get_eps(
+        ) const;
+        /*!
+ ensures + - When doing RMS normalization, we are dividing by the root mean square. + This epsilon value returned by this function is added to the + mean square to prevent division by zero. + !*/ + + void set_eps( + float val + ); + /*! + requires + - val > 0 + ensures + - #get_eps() == val + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). + !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. 
+ !*/ + }; + + template + using rms_norm = add_layer; + // ---------------------------------------------------------------------------------------- enum layer_mode diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h index c40bcbd33e..a601066795 100644 --- a/dlib/dnn/visitors.h +++ b/dlib/dnn/visitors.h @@ -308,6 +308,14 @@ namespace dlib set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0); } + template + void disable_input_bias(add_layer& l) + { + disable_bias(l.subnet().layer_details()); + set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0); + set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0); + } + template void disable_input_bias(add_layer, U, E>& l) { @@ -333,6 +341,14 @@ namespace dlib set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); } + template class R, typename U, typename E> + void disable_input_bias(add_layer, E>& l) + { + disable_bias(l.subnet().get_repeated_layer(0).layer_details()); + set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); + set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0); + } + // handle input repeat layer with tag case template void disable_input_bias(add_layer, add_tag_layer, E>& ) @@ -344,6 +360,11 @@ namespace dlib { } + template + void disable_input_bias(add_layer, E>& ) + { + } + // handle tag layer case template void disable_input_bias(add_layer, add_tag_layer, E>& ) @@ -355,6 +376,11 @@ namespace dlib { } + template + void disable_input_bias(add_layer, E>& ) + { + } + // handle skip layer case template class TAG, typename U, typename E> void disable_input_bias(add_layer, add_skip_layer, E>& ) @@ -366,6 +392,11 @@ namespace dlib { } + template