diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 21840c6b3d20..91e36163ff32 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -13,9 +13,6 @@ // specific language governing permissions and limitations under the License. #include "layernorm_x86.h" -#include "x86_usability.h" - -#include #if __SSE2__ #include @@ -24,6 +21,8 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + namespace ncnn { LayerNorm_x86::LayerNorm_x86() @@ -33,37 +32,53 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, int elemcount, int size) +static void layernorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) { - int i = 0; + const int size = elemcount * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _sum_512 = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _sum_512 = _mm512_add_ps(_sum_512, _cur); - } + __m512 _mean_avx512 = _mm512_set1_ps(0.f); #endif // __AVX512F__ - __m256 _sum_256 = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _sum_256 = _mm256_add_ps(_sum_256, _cur); - } + __m256 _mean_avx = _mm256_set1_ps(0.f); #endif // __AVX__ - __m128 _sum_128 = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _sum_128 = _mm_add_ps(_sum_128, _cur); - } + __m128 _mean = _mm_set1_ps(0.f); #endif // __SSE2__ - float sum = 0.0f; - for (; i < size; ++i, ++ptr) + float mean = 0.f; { - sum += *ptr; + const float* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr0); + _mean_avx512 = _mm512_add_ps(_mean_avx512, _p); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr0); + _mean_avx = _mm256_add_ps(_mean_avx, _p); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr0); + _mean = _mm_add_ps(_mean, _p); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + mean += ptr0[0]; + ptr0++; + } } #if __SSE2__ @@ -71,110 +86,128 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #if __AVX512F__ if (elempack == 16) { - __m512 _mean = _mm512_div_ps(_sum_512, _mm512_set1_ps((float)elemcount)); - _mm512_storeu_ps(mean, _mean); + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + _mean_avx512 = _mm512_div_ps(_mean_avx512, _elemcount); } #endif // __AVX512F__ - if (elempack == 8) { #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); - _sum_256 = _mm256_add_ps(_sum_256, _high); - _sum_256 = _mm256_add_ps(_sum_256, _low); + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + _mean_avx = _mm256_add_ps(_mean_avx, _mean1); } #endif // __AVX512F__ - __m256 _mean = _mm256_div_ps(_sum_256, _mm256_set1_ps((float)elemcount)); - _mm256_storeu_ps(mean, _mean); + + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + _mean_avx = _mm256_div_ps(_mean_avx, _elemcount); +#if __AVX512F__ + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ } #endif // __AVX__ - if (elempack == 4) { #if __AVX__ #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); - _sum_256 = _mm256_add_ps(_sum_256, _high); - _sum_256 = _mm256_add_ps(_sum_256, _low); + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + _mean_avx = _mm256_add_ps(_mean_avx, _mean1); } #endif // __AVX512F__ { - __m128 _low = _mm256_castps256_ps128(_sum_256); - __m128 _high = _mm256_extractf128_ps(_sum_256, 1); - _sum_128 = _mm_add_ps(_sum_128, _low); - _sum_128 = _mm_add_ps(_sum_128, _high); + __m128 _mean0 = _mm256_castps256_ps128(_mean_avx); + __m128 _mean1 = _mm256_extractf128_ps(_mean_avx, 1); + _mean = _mm_add_ps(_mean, _mean0); + _mean = _mm_add_ps(_mean, _mean1); } #endif // __AVX__ - __m128 _mean = _mm_div_ps(_sum_128, _mm_set1_ps((float)elemcount)); - _mm_storeu_ps(mean, _mean); + + __m128 _elemcount = _mm_set1_ps((float)elemcount); + _mean = _mm_div_ps(_mean, _elemcount); +#if __AVX__ + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); +#if __AVX512F__ + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ +#endif // __AVX__ } #endif // __SSE2__ - if (elempack == 1) { #if __SSE2__ #if __AVX__ #if __AVX512F__ - sum += _mm512_comp_reduce_add_ps(_sum_512); + mean += _mm512_comp_reduce_add_ps(_mean_avx512); #endif // __AVX512F__ - sum += _mm256_reduce_add_ps(_sum_256); + mean += _mm256_reduce_add_ps(_mean_avx); #endif // __AVX__ - sum += _mm_reduce_add_ps(_sum_128); + mean += _mm_reduce_add_ps(_mean); #endif // __SSE2__ - mean[0] = sum / elemcount; - } -} -static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, const float* mean, int elempack, int elemcount, int size) -{ - const float _mean = mean[0]; + mean = mean / elemcount; #if __SSE2__ - __m128 _mean_128 = (elempack == 4) ? _mm_loadu_ps(mean) : _mm_set1_ps(_mean); + _mean = _mm_set1_ps(mean); #if __AVX__ - __m256 _mean_256 = (elempack == 8) ? _mm256_loadu_ps(mean) : _mm256_insertf128_ps(_mm256_castps128_ps256(_mean_128), _mean_128, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - __m512 _mean_512 = (elempack == 16) ? _mm512_loadu_ps(mean) : _mm512_insertf32x8(_mm512_castps256_ps512(_mean_256), _mean_256, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ + } - int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _sq_sum_512 = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_sub_ps(_cur, _mean_512); - _sq_sum_512 = _mm512_fmadd_ps(_cur, _cur, _sq_sum_512); - } + __m512 _var_avx512 = _mm512_set1_ps(0.f); #endif // __AVX512F__ - __m256 _sq_sum_256 = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_sub_ps(_cur, _mean_256); - _sq_sum_256 = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum_256); - } + __m256 _var_avx = _mm256_set1_ps(0.f); #endif // __AVX__ - __m128 _sq_sum_128 = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_sub_ps(_cur, _mean_128); - _sq_sum_128 = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum_128); - } + __m128 _var = _mm_set1_ps(0.f); #endif // __SSE2__ - float sq_sum = 0.0f; - for (; i < size; ++i, ++ptr) + float var = 0.f; { - float tmp = *ptr - _mean; - sq_sum += tmp * tmp; + const float* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr0); + _p = _mm512_sub_ps(_p, _mean_avx512); + _var_avx512 = _mm512_fmadd_ps(_p, _p, _var_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr0); + _p = _mm256_sub_ps(_p, _mean_avx); + _var_avx = _mm256_comp_fmadd_ps(_p, _p, _var_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr0); + _p = _mm_sub_ps(_p, _mean); + _var = _mm_comp_fmadd_ps(_p, _p, _var); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = ptr0[0] - mean; + var += v * v; + ptr0++; + } } #if __SSE2__ @@ -182,384 +215,332 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, const float* mean, #if __AVX512F__ if (elempack == 16) { - __m512 _var = _mm512_div_ps(_sq_sum_512, _mm512_set1_ps((float)elemcount)); - _mm512_storeu_ps(var, _var); + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + __m512 _eps = _mm512_set1_ps(eps); + _var_avx512 = _mm512_div_ps(_var_avx512, _elemcount); + _var_avx512 = _mm512_add_ps(_var_avx512, _eps); + __m256 _var0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 0)); + __m256 _var1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 1)); + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var0), _var1, 1); + _mean_avx512 = _mm512_mul_ps(_mean_avx512, _var_avx512); } #endif // __AVX512F__ - if (elempack == 8) { #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sq_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); } #endif // __AVX512F__ - __m256 _var = _mm256_div_ps(_sq_sum_256, _mm256_set1_ps((float)elemcount)); - _mm256_storeu_ps(var, _var); + + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + __m256 _eps = _mm256_set1_ps(eps); + _var_avx = _mm256_div_ps(_var_avx, _elemcount); + _var_avx = _mm256_add_ps(_var_avx, _eps); + _var_avx = _mm256_rsqrt_ps(_var_avx); + _mean_avx = _mm256_mul_ps(_mean_avx, _var_avx); +#if __AVX512F__ + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ } #endif // __AVX__ - if (elempack == 4) { #if __AVX__ #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sq_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); } #endif // __AVX512F__ { - __m128 _low = _mm256_castps256_ps128(_sq_sum_256); - __m128 _high = _mm256_extractf128_ps(_sq_sum_256, 1); - _sq_sum_128 = _mm_add_ps(_sq_sum_128, _low); - _sq_sum_128 = _mm_add_ps(_sq_sum_128, _high); + __m128 _var0 = _mm256_castps256_ps128(_var_avx); + __m128 _var1 = _mm256_extractf128_ps(_var_avx, 1); + _var = _mm_add_ps(_var, _var0); + _var = _mm_add_ps(_var, _var1); } #endif // __AVX__ - __m128 _var = _mm_div_ps(_sq_sum_128, _mm_set1_ps((float)elemcount)); - _mm_storeu_ps(var, _var); - } -#endif // __SSE2__ - if (elempack == 1) - { -#if __SSE2__ + __m128 _elemcount = _mm_set1_ps((float)elemcount); + __m128 _eps = _mm_set1_ps(eps); + _var = _mm_div_ps(_var, _elemcount); + _var = _mm_add_ps(_var, _eps); + _var = _mm_rsqrt_ps(_var); + _mean = _mm_mul_ps(_mean, _var); #if __AVX__ + _var_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_var), _var, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - sq_sum += _mm512_comp_reduce_add_ps(_sq_sum_512); + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ - sq_sum += _mm256_reduce_add_ps(_sq_sum_256); #endif // __AVX__ - sq_sum += _mm_reduce_add_ps(_sq_sum_128); -#endif // __SSE2__ - var[0] = sq_sum / elemcount; } -} - -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, const float* a, const float* b, int elempack, int size) -{ - const float _a = a[0]; - const float _b = b[0]; +#endif // __SSE2__ + if (elempack == 1) + { #if __SSE2__ - __m128 _a_128 = (elempack == 4) ? _mm_loadu_ps(a) : _mm_set1_ps(_a); - __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps(b) : _mm_set1_ps(_b); #if __AVX__ - __m256 _a_256 = (elempack == 8) ? _mm256_loadu_ps(a) : _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps(b) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = (elempack == 16) ? _mm512_loadu_ps(a) : _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps(b) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); + var += _mm512_comp_reduce_add_ps(_var_avx512); #endif // __AVX512F__ + var += _mm256_reduce_add_ps(_var_avx); #endif // __AVX__ + var += _mm_reduce_add_ps(_var); #endif // __SSE2__ - int i = 0; + var = 1.f / sqrtf(var / elemcount + eps); + mean = mean * var; #if __SSE2__ + _var = _mm_set1_ps(var); + _mean = _mm_set1_ps(mean); #if __AVX__ + _var_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_var), _var, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _mm512_storeu_ps(ptr, _cur); - } + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _mm256_storeu_ps(ptr, _cur); - } #endif // __AVX__ - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _mm_storeu_ps(ptr, _cur); - } #endif // __SSE2__ - for (; i < size; ++i, ++ptr) - { - *ptr = (*ptr) * _a + _b; } -} -static NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, const float* a, const float* b, const float* gamma, const float* beta, int elempack, int size) -{ + if (gamma_ptr && beta_ptr) + { + int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) - { - int i = 0; - __m512 _a_512 = _mm512_loadu_ps(a); - __m512 _b_512 = _mm512_loadu_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) + if (elempack == 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_set1_ps(*gamma); - __m512 _beta = _mm512_set1_ps(*beta); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]); + __m512 _beta = _mm512_set1_ps(beta_ptr[0]); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 1; + beta_ptr += 1; + } } - } #endif // __AVX512F__ - - if (elempack == 8) - { - int i = 0; - __m256 _a_256 = _mm256_loadu_ps(a); - __m256 _b_256 = _mm256_loadu_ps(b); -#if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 2, beta += 2) + if (elempack == 8) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - _gamma_0 = _mm512_mask_blend_ps(0xFF00, _gamma_0, _gamma_1); - _beta_0 = _mm512_mask_blend_ps(0xFF00, _beta_0, _beta_1); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]); + __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]); + __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma0), _gamma1, 1); + __m256 _beta0 = _mm256_set1_ps(beta_ptr[0]); + __m256 _beta1 = _mm256_set1_ps(beta_ptr[1]); + __m512 _beta = _mm512_insertf32x8(_mm512_castps256_ps512(_beta0), _beta1, 1); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 2; + beta_ptr += 2; + } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_set1_ps(*gamma); - __m256 _beta = _mm256_set1_ps(*beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]); + __m256 _beta = _mm256_set1_ps(beta_ptr[0]); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 1; + beta_ptr += 1; + } } - } #endif // __AVX__ - - if (elempack == 4) - { - int i = 0; - __m128 _a_128 = _mm_loadu_ps(a); - __m128 _b_128 = _mm_loadu_ps(b); + if (elempack == 4) + { #if __AVX__ - __m256 _a_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 4, beta += 4) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _gamma_2 = _mm512_set1_ps(gamma[2]); - __m512 _gamma_3 = _mm512_set1_ps(gamma[3]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - __m512 _beta_2 = _mm512_set1_ps(beta[2]); - __m512 _beta_3 = _mm512_set1_ps(beta[3]); - _gamma_0 = _mm512_mask_blend_ps(0x00F0, _gamma_0, _gamma_1); - _gamma_0 = _mm512_mask_blend_ps(0x0F00, _gamma_0, _gamma_2); - _gamma_0 = _mm512_mask_blend_ps(0xF000, _gamma_0, _gamma_3); - _beta_0 = _mm512_mask_blend_ps(0x00F0, _beta_0, _beta_1); - _beta_0 = _mm512_mask_blend_ps(0x0F00, _beta_0, _beta_2); - _beta_0 = _mm512_mask_blend_ps(0xF000, _beta_0, _beta_3); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]); + __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]); + __m256 _gamma01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + __m256 _gamma23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma2), _gamma3, 1); + __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma01), _gamma23, 1); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m128 _beta2 = _mm_set1_ps(beta_ptr[2]); + __m128 _beta3 = _mm_set1_ps(beta_ptr[3]); + __m256 _beta01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta0), _beta1, 1); + __m256 _beta23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta2), _beta3, 1); + __m512 _beta = _mm512_insertf32x8(_mm512_castps256_ps512(_beta01), _beta23, 1); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 4; + beta_ptr += 4; + } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 2, beta += 2) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma_0 = _mm256_set1_ps(gamma[0]); - __m256 _gamma_1 = _mm256_set1_ps(gamma[1]); - __m256 _beta_0 = _mm256_set1_ps(beta[0]); - __m256 _beta_1 = _mm256_set1_ps(beta[1]); - _gamma_0 = _mm256_blend_ps(_gamma_0, _gamma_1, 0xF0); - _beta_0 = _mm256_blend_ps(_beta_0, _beta_1, 0xF0); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm256_storeu_ps(ptr, _cur); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m256 _gamma = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m256 _beta = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta0), _beta1, 1); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 2; + beta_ptr += 2; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(gamma_ptr[0]); + __m128 _beta = _mm_set1_ps(beta_ptr[0]); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } } + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma_ptr); + __m512 _beta = _mm512_loadu_ps(beta_ptr); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 16; + beta_ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma_ptr); + __m256 _beta = _mm256_loadu_ps(beta_ptr); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 8; + beta_ptr += 8; + } #endif // __AVX__ - - for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma_ptr); + __m128 _beta = _mm_loadu_ps(beta_ptr); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_set1_ps(*gamma); - __m128 _beta = _mm_set1_ps(*beta); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + ptr[0] = (ptr[0] * var - mean) * gamma_ptr[0] + beta_ptr[0]; + ptr++; + gamma_ptr++; + beta_ptr++; } } -#endif // __SSE2__ - - if (elempack == 1) + else { int i = 0; - const float _a = a[0]; - const float _b = b[0]; #if __SSE2__ - __m128 _a_128 = _mm_set1_ps(_a); - __m128 _b_128 = _mm_set1_ps(_b); #if __AVX__ - __m256 _a_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) + for (; i + 15 < size; i += 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + __m512 _p = _mm512_loadu_ps(ptr); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _mm512_storeu_ps(ptr, _p); + ptr += 16; } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) + for (; i + 7 < size; i += 8) { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + __m256 _p = _mm256_loadu_ps(ptr); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _mm256_storeu_ps(ptr, _p); + ptr += 8; } #endif // __AVX__ - - for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + for (; i + 3 < size; i += 4) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + __m128 _p = _mm_loadu_ps(ptr); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _mm_storeu_ps(ptr, _p); + ptr += 4; } #endif // __SSE2__ - - for (; i < size; ++i, ++ptr, ++gamma, ++beta) + for (; i < size; i++) { - *ptr = ((*ptr) * _a + _b) * (*gamma) + (*beta); + ptr[0] = ptr[0] * var - mean; + ptr++; } } } -static NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size, const float* gamma, const float* beta, int affine, float eps) -{ - float mean[16] = {0.f}, var[16] = {0.f}; - fast_mean(ptr, mean, elempack, elemcount, size); - fast_var(ptr, var, mean, elempack, elemcount, size); - float *a = var, *b = mean; - -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _a = _mm512_set1_ps(1.0f); - __m512 _eps = _mm512_set1_ps(eps); - __m512 _b = _mm512_setzero_ps(); - __m512 _var = _mm512_loadu_ps(var); - _var = _mm512_add_ps(_var, _eps); - __m512 _sqrt_var = _mm512_sqrt_ps(_var); - _a = _mm512_div_ps(_a, _sqrt_var); - __m512 _mean = _mm512_loadu_ps(mean); - _b = _mm512_fnmadd_ps(_mean, _a, _b); - - _mm512_storeu_ps(a, _a); - _mm512_storeu_ps(b, _b); - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _a = _mm256_set1_ps(1.0f); - __m256 _eps = _mm256_set1_ps(eps); - __m256 _b = _mm256_setzero_ps(); - __m256 _var = _mm256_loadu_ps(var); - _var = _mm256_add_ps(_var, _eps); - __m256 _sqrt_var = _mm256_sqrt_ps(_var); - _a = _mm256_div_ps(_a, _sqrt_var); - __m256 _mean = _mm256_loadu_ps(mean); - _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); - - _mm256_storeu_ps(a, _a); - _mm256_storeu_ps(b, _b); - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _a = _mm_set1_ps(1.0f); - __m128 _eps = _mm_set1_ps(eps); - __m128 _b = _mm_setzero_ps(); - __m128 _var = _mm_loadu_ps(var); - _var = _mm_add_ps(_var, _eps); - __m128 _sqrt_var = _mm_sqrt_ps(_var); - _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean = _mm_loadu_ps(mean); - _b = _mm_comp_fnmadd_ps(_mean, _a, _b); - - _mm_storeu_ps(a, _a); - _mm_storeu_ps(b, _b); - } -#endif // __SSE2__ - if (elempack == 1) - { - a[0] = 1.0f / sqrtf(var[0] + eps); - b[0] = -mean[0] * (a[0]); - } - - if (affine) - { - fast_fmadd_fmadd(ptr, a, b, gamma, beta, elempack, size); - } - else - { - fast_fmadd(ptr, a, b, elempack, size); - } -} - int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - int dims = bottom_top_blob.dims; - int elempack = bottom_top_blob.elempack; - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - int channels = bottom_top_blob.c; - - const float* gamma = gamma_data; - const float* beta = beta_data; + const int dims = bottom_top_blob.dims; + const int elempack = bottom_top_blob.elempack; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; if (dims == 1) { - int elemcount = w * elempack; + // assert affine_size == w + float* ptr = bottom_top_blob; - // 1D layer norm is special. Treat them as unpacked. - fast_1d_layer_norm(ptr, 1, elemcount, elemcount, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w * elempack, 1); } if (dims == 2) { + // assert affine_size == w + #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; ++i) + for (int i = 0; i < h; i++) { float* ptr = bottom_top_blob.row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); } } @@ -568,22 +549,22 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons if (affine_size == w) { #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; ++q) + for (int q = 0; q < channels; q++) { - for (int i = 0; i < h; ++i) + for (int i = 0; i < h; i++) { float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); } } } else // if (affine_size == w * h) { #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; ++q) + for (int q = 0; q < channels; q++) { float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm(ptr, elempack, w * h, w * h * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w * h, elempack); } } }