From 27568945c5b386eafb307da90d6332a32aca2c2a Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Dec 2024 08:04:14 +0000 Subject: [PATCH] opt packa packb avx --- src/layer/x86/gemm_int8.h | 1227 ++++++++++++++++--------------------- 1 file changed, 515 insertions(+), 712 deletions(-) diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 0f67eaedd5e..f0ffe3443a8 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -2379,7 +2379,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2453,7 +2453,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2537,7 +2537,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -2633,7 +2633,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m128 _p0 = _mm_loadu_ps(p0); @@ -2742,138 +2742,70 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; - const float scale0 = scales[i + ii]; - const float scale1 = scales[i + ii + 1]; - const float scale2 = scales[i + ii + 2]; - const float scale3 = scales[i + ii + 3]; - const float scale4 = scales[i + ii + 4]; - const float scale5 = scales[i + ii + 5]; - const float scale6 = scales[i + ii + 6]; - const float scale7 = scales[i + ii + 7]; + __m256 _scales = _mm256_loadu_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ if (elempack == 8) { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[8] * scale0); - pp[2] = float2int8(p0[16] * scale0); - pp[3] = float2int8(p0[24] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[9] * scale1); - pp[6] = float2int8(p0[17] * scale1); - pp[7] = float2int8(p0[25] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[10] * scale2); - pp[10] = float2int8(p0[18] * scale2); - pp[11] = float2int8(p0[26] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[11] * scale3); - pp[14] = float2int8(p0[19] * scale3); - pp[15] = float2int8(p0[27] * scale3); - pp[16] = float2int8(p0[4] * scale4); - pp[17] = float2int8(p0[12] * scale4); - pp[18] = float2int8(p0[20] * scale4); - pp[19] = float2int8(p0[28] * scale4); - pp[20] = float2int8(p0[5] * scale5); - pp[21] = float2int8(p0[13] * scale5); - pp[22] = float2int8(p0[21] * scale5); - pp[23] = 
float2int8(p0[29] * scale5); - pp[24] = float2int8(p0[6] * scale6); - pp[25] = float2int8(p0[14] * scale6); - pp[26] = float2int8(p0[22] * scale6); - pp[27] = float2int8(p0[30] * scale6); - pp[28] = float2int8(p0[7] * scale7); - pp[29] = float2int8(p0[15] * scale7); - pp[30] = float2int8(p0[23] * scale7); - pp[31] = float2int8(p0[31] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + _p2 = _mm256_mul_ps(_p2, _scales); + _p3 = _mm256_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 32; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[8] * scale0); - pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[9] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[10] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[11] * scale3); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); + #if __AVX2__ - pp[8] = float2int8(p0[4] * scale4); - pp[9] = float2int8(p0[12] * scale4); - pp[10] = float2int8(p0[5] * scale5); - pp[11] = float2int8(p0[13] * scale5); - pp[12] = float2int8(p0[6] * scale6); - pp[13] = float2int8(p0[14] * scale6); - pp[14] = float2int8(p0[7] * scale7); - pp[15] = float2int8(p0[15] * scale7); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; #else - pp1[0] = float2int8(p0[4] * scale4); - pp1[1] = float2int8(p0[12] * scale4); - pp1[2] = float2int8(p0[5] * scale5); - pp1[3] = float2int8(p0[13] * scale5); - pp1[4] = float2int8(p0[6] * scale6); - pp1[5] = float2int8(p0[14] * scale6); - pp1[6] = float2int8(p0[7] * scale7); - pp1[7] = float2int8(p0[15] * scale7); + 
_mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); pp += 8; pp1 += 8; #endif @@ -2881,21 +2813,18 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * scale2); - pp[3] = float2int8(p0[3] * scale3); + __m256 _p = _mm256_loadu_ps(p0); + + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + #if __AVX2__ - pp[4] = float2int8(p0[4] * scale4); - pp[5] = float2int8(p0[5] * scale5); - pp[6] = float2int8(p0[6] * scale6); - pp[7] = float2int8(p0[7] * scale7); + *(int64_t*)pp = v; pp += 8; #else - pp1[0] = float2int8(p0[4] * scale4); - pp1[1] = float2int8(p0[5] * scale5); - pp1[2] = float2int8(p0[6] * scale6); - pp1[3] = float2int8(p0[7] * scale7); + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); pp += 4; pp1 += 4; #endif @@ -2906,125 +2835,69 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = float2int8(p0[8] * scale0); - pp[3] = float2int8(p0[12] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[5] * scale1); - pp[6] = float2int8(p0[9] * scale1); - pp[7] = float2int8(p0[13] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[6] * scale2); - pp[10] = float2int8(p0[10] * scale2); - pp[11] = float2int8(p0[14] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[7] * scale3); - pp[14] = float2int8(p0[11] * scale3); - pp[15] = float2int8(p0[15] * scale3); - pp[16] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp[17] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp[18] = float2int8(p0[A_hstep * 4 + 8] * scale4); - pp[19] = float2int8(p0[A_hstep * 4 + 12] * scale4); - pp[20] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[21] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp[22] = float2int8(p0[A_hstep * 4 + 9] * scale5); - pp[23] = float2int8(p0[A_hstep * 4 + 13] * scale5); - pp[24] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[25] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp[26] = float2int8(p0[A_hstep * 4 + 10] * scale6); - pp[27] = float2int8(p0[A_hstep * 4 + 14] * scale6); - pp[28] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[29] = float2int8(p0[A_hstep * 4 + 7] * scale7); - pp[30] = float2int8(p0[A_hstep * 4 + 11] * scale7); - pp[31] = float2int8(p0[A_hstep * 4 + 15] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m256 _p0 
= _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 4); + __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 4 + 8); + + __m256 _t0 = _mm256_permute2f128_ps(_p0, _p2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t1 = _mm256_permute2f128_ps(_p0, _p2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _t2 = _mm256_permute2f128_ps(_p1, _p3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t3 = _mm256_permute2f128_ps(_p1, _p3, _MM_SHUFFLE(0, 3, 0, 1)); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + _t2 = _mm256_mul_ps(_t2, _scales); + _t3 = _mm256_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx(_t0, _t2); + __m128i _pp1 = float2int8_avx(_t1, _t3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 16; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[5] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[6] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[7] * scale3); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep * 4); + + __m256 _t0 = _mm256_permute2f128_ps(_p0, _p1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t1 = _mm256_permute2f128_ps(_p0, _p1, _MM_SHUFFLE(0, 3, 0, 1)); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + + __m128i _pp = float2int8_avx(_t0, _t1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); + #if __AVX2__ - pp[8] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp[9] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp[10] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[11] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp[12] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[13] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp[14] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[15] = float2int8(p0[A_hstep * 4 + 7] * scale7); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; #else - pp1[0] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp1[2] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp1[3] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp1[4] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp1[5] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp1[6] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp1[7] = float2int8(p0[A_hstep * 4 + 7] * scale7); + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); pp += 8; pp1 += 8; #endif @@ -3032,21 +2905,20 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * 
scale2); - pp[3] = float2int8(p0[3] * scale3); + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep * 4); + + __m256 _p = combine4x2_ps(_p0, _p1); + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + #if __AVX2__ - pp[4] = float2int8(p0[A_hstep * 4] * scale4); - pp[5] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[6] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[7] = float2int8(p0[A_hstep * 4 + 3] * scale7); + *(int64_t*)pp = v; pp += 8; #else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp1[2] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp1[3] = float2int8(p0[A_hstep * 4 + 3] * scale7); + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); pp += 4; pp1 += 4; #endif @@ -3057,125 +2929,88 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[A_hstep] * scale1); - pp[5] = float2int8(p0[A_hstep + 1] * scale1); - pp[6] = float2int8(p0[A_hstep + 2] * scale1); - pp[7] = float2int8(p0[A_hstep + 3] * scale1); - pp[8] = float2int8(p0[A_hstep * 2] * scale2); - pp[9] = float2int8(p0[A_hstep * 2 + 1] * scale2); - pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); - pp[11] = float2int8(p0[A_hstep * 2 + 3] * scale2); - pp[12] = float2int8(p0[A_hstep * 3] * scale3); - pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); - pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); - pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); - pp[16] = float2int8(p0[A_hstep * 4] * scale4); - pp[17] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp[18] = float2int8(p0[A_hstep * 4 + 2] * scale4); - pp[19] = float2int8(p0[A_hstep * 4 + 3] * scale4); - pp[20] = float2int8(p0[A_hstep * 5] * scale5); - pp[21] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp[22] = float2int8(p0[A_hstep * 5 + 2] * scale5); - pp[23] = float2int8(p0[A_hstep * 5 + 3] * scale5); - pp[24] = float2int8(p0[A_hstep * 6] * scale6); - pp[25] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp[26] = float2int8(p0[A_hstep * 6 + 2] * scale6); - pp[27] = float2int8(p0[A_hstep * 6 + 3] * scale6); - pp[28] = float2int8(p0[A_hstep * 7] * scale7); - pp[29] = float2int8(p0[A_hstep * 7 + 1] * scale7); - pp[30] = float2int8(p0[A_hstep * 7 + 2] * scale7); - pp[31] = float2int8(p0[A_hstep * 7 + 3] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + 
__m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + A_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + A_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + A_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + A_hstep * 7); + + __m256 _t0 = combine4x2_ps(_p0, _p4); + __m256 _t1 = combine4x2_ps(_p1, _p5); + __m256 _t2 = combine4x2_ps(_p2, _p6); + __m256 _t3 = combine4x2_ps(_p3, _p7); + + __m256 _t4 = _mm256_unpacklo_ps(_t0, _t1); + __m256 _t5 = _mm256_unpackhi_ps(_t0, _t1); + __m256 _t6 = _mm256_unpacklo_ps(_t2, _t3); + __m256 _t7 = _mm256_unpackhi_ps(_t2, _t3); + + _t0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_t4), _mm256_castps_pd(_t6))); + _t1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_t4), _mm256_castps_pd(_t6))); + _t2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_t5), _mm256_castps_pd(_t7))); + _t3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_t5), _mm256_castps_pd(_t7))); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + _t2 = _mm256_mul_ps(_t2, _scales); + _t3 = _mm256_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx(_t0, _t2); + __m128i _pp1 = float2int8_avx(_t1, _t3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[A_hstep] * scale1); - pp[3] = float2int8(p0[A_hstep + 1] * scale1); - pp[4] = float2int8(p0[A_hstep * 2] * scale2); - pp[5] = float2int8(p0[A_hstep * 2 + 1] * scale2); - pp[6] = float2int8(p0[A_hstep * 3] * scale3); - pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale3); #if __AVX2__ - pp[8] = float2int8(p0[A_hstep * 4] * scale4); - pp[9] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp[10] = float2int8(p0[A_hstep * 5] * scale5); - pp[11] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp[12] = float2int8(p0[A_hstep * 6] * scale6); - pp[13] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp[14] = float2int8(p0[A_hstep * 7] * scale7); - pp[15] = float2int8(p0[A_hstep * 7 + 1] * scale7); + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A_hstep)); + + __m256 _p0 = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); + __m256 _p1 = _mm256_i32gather_ps(p0 + 1, _vindex, sizeof(float)); +#else + __m256 _p0 = _mm256_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3], p0[A_hstep * 4], p0[A_hstep * 5], p0[A_hstep * 6], p0[A_hstep * 7]); + __m256 _p1 = _mm256_setr_ps(p0[1], p0[A_hstep + 1], p0[A_hstep * 2 + 1], p0[A_hstep * 3 + 1], p0[A_hstep * 4 + 1], p0[A_hstep * 5 + 1], p0[A_hstep * 6 + 1], p0[A_hstep * 7 + 1]); +#endif + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = 
_mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX2__ + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; #else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp1[2] = float2int8(p0[A_hstep * 5] * scale5); - pp1[3] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp1[4] = float2int8(p0[A_hstep * 6] * scale6); - pp1[5] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp1[6] = float2int8(p0[A_hstep * 7] * scale7); - pp1[7] = float2int8(p0[A_hstep * 7 + 1] * scale7); + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); pp += 8; pp1 += 8; #endif @@ -3183,21 +3018,25 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[A_hstep] * scale1); - pp[2] = float2int8(p0[A_hstep * 2] * scale2); - pp[3] = float2int8(p0[A_hstep * 3] * scale3); #if __AVX2__ - pp[4] = float2int8(p0[A_hstep * 4] * scale4); - pp[5] = float2int8(p0[A_hstep * 5] * scale5); - pp[6] = float2int8(p0[A_hstep * 6] * scale6); - pp[7] = float2int8(p0[A_hstep * 7] * scale7); + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A_hstep)); + + __m256 _p = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m256 _p = _mm256_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3], p0[A_hstep * 4], p0[A_hstep * 5], p0[A_hstep * 6], p0[A_hstep * 7]); +#endif + + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + +#if __AVX2__ + *(int64_t*)pp = v; pp += 8; #else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 5] * scale5); - pp1[2] = float2int8(p0[A_hstep * 6] * scale6); - pp1[3] = float2int8(p0[A_hstep * 7] * scale7); + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); pp += 4; pp1 += 4; #endif @@ -4302,7 +4141,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 15 < max_kk; kk += 16) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4490,7 +4329,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 7 < max_kk; kk += 8) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4602,7 +4441,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -4674,7 +4513,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { int kk = 0; #if __AVX512VNNI__ - __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { __m512 _p0 = _mm512_loadu_ps(p0); @@ -7254,82 +7093,76 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i { const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; -#if __AVX__ - if (elempack == 8) - { - int kk = 0; -#if __AVX512VNNI__ || __AVXVNNI__ - for (; kk + 3 < max_kk; kk += 4) - { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = 
float2int8(p0[8] * scale) + 127; - pp[2] = float2int8(p0[16] * scale) + 127; - pp[3] = float2int8(p0[24] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[9] * scale) + 127; - pp[6] = float2int8(p0[17] * scale) + 127; - pp[7] = float2int8(p0[25] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[10] * scale) + 127; - pp[10] = float2int8(p0[18] * scale) + 127; - pp[11] = float2int8(p0[26] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[11] * scale) + 127; - pp[14] = float2int8(p0[19] * scale) + 127; - pp[15] = float2int8(p0[27] * scale) + 127; - pp[16 + 0] = float2int8(p0[4] * scale) + 127; - pp[16 + 1] = float2int8(p0[12] * scale) + 127; - pp[16 + 2] = float2int8(p0[20] * scale) + 127; - pp[16 + 3] = float2int8(p0[28] * scale) + 127; - pp[16 + 4] = float2int8(p0[5] * scale) + 127; - pp[16 + 5] = float2int8(p0[13] * scale) + 127; - pp[16 + 6] = float2int8(p0[21] * scale) + 127; - pp[16 + 7] = float2int8(p0[29] * scale) + 127; - pp[16 + 8] = float2int8(p0[6] * scale) + 127; - pp[16 + 9] = float2int8(p0[14] * scale) + 127; - pp[16 + 10] = float2int8(p0[22] * scale) + 127; - pp[16 + 11] = float2int8(p0[30] * scale) + 127; - pp[16 + 12] = float2int8(p0[7] * scale) + 127; - pp[16 + 13] = float2int8(p0[15] * scale) + 127; - pp[16 + 14] = float2int8(p0[23] * scale) + 127; - pp[16 + 15] = float2int8(p0[31] * scale) + 127; +#if __AVX__ + __m256 _scale = _mm256_set1_ps(scale); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ +#else + __m128 _scale = _mm_set1_ps(scale); +#endif // __AVX__ + +#if __AVX__ + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _pp = _mm256_add_epi8(_pp, _v127); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 32; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[8] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[9] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[10] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[11] * scale); - pp[8] = float2int8(p0[4] * scale); - pp[9] = float2int8(p0[12] * scale); - pp[10] = float2int8(p0[5] * scale); - pp[11] = float2int8(p0[13] * scale); - pp[12] = float2int8(p0[6] * scale); - pp[13] = float2int8(p0[14] * scale); - pp[14] = float2int8(p0[7] * scale); - pp[15] = float2int8(p0[15] * scale); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); + + 
_mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += 16; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[4] * scale); - pp[5] = float2int8(p0[5] * scale); - pp[6] = float2int8(p0[6] * scale); - pp[7] = float2int8(p0[7] * scale); + __m256 _p = _mm256_loadu_ps(p0); + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); + + *(int64_t*)pp = v; pp += 8; p0 += 8; } @@ -7341,75 +7174,86 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[4] * scale) + 127; - pp[2] = float2int8(p0[8] * scale) + 127; - pp[3] = float2int8(p0[12] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[5] * scale) + 127; - pp[6] = float2int8(p0[9] * scale) + 127; - pp[7] = float2int8(p0[13] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[6] * scale) + 127; - pp[10] = float2int8(p0[10] * scale) + 127; - pp[11] = float2int8(p0[14] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[7] * scale) + 127; - pp[14] = float2int8(p0[11] * scale) + 127; - pp[15] = float2int8(p0[15] * scale) + 127; - pp[16] = float2int8(p0[B_hstep * 4 + 0] * scale) + 127; - pp[17] = float2int8(p0[B_hstep * 4 + 4] * scale) + 127; - pp[18] = float2int8(p0[B_hstep * 4 + 8] * scale) + 127; - pp[19] = float2int8(p0[B_hstep * 4 + 12] * scale) + 127; - pp[20] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; - pp[21] = float2int8(p0[B_hstep * 4 + 5] * scale) + 127; - pp[22] = float2int8(p0[B_hstep * 4 + 9] * scale) + 127; - pp[23] = float2int8(p0[B_hstep * 4 + 13] * scale) + 127; - pp[24] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; - pp[25] = float2int8(p0[B_hstep * 4 + 6] * scale) + 127; - pp[26] = float2int8(p0[B_hstep * 4 + 10] * scale) + 127; - pp[27] = float2int8(p0[B_hstep * 4 + 14] * scale) + 127; - pp[28] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; - pp[29] = float2int8(p0[B_hstep * 4 + 7] * scale) + 127; - pp[30] = float2int8(p0[B_hstep * 4 + 11] * scale) + 127; - pp[31] = float2int8(p0[B_hstep * 4 + 15] * scale) + 127; + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 4); + __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 4 + 8); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _pp = _mm256_add_epi8(_pp, _v127); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm256_shuffle_epi8(_pp, combine4x2_epi32(_si, _si)); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 16; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[4] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[5] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[6] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[B_hstep * 4] * scale); - pp[9] = float2int8(p0[B_hstep * 4 + 4] * scale); - pp[10] = 
float2int8(p0[B_hstep * 4 + 1] * scale); - pp[11] = float2int8(p0[B_hstep * 4 + 5] * scale); - pp[12] = float2int8(p0[B_hstep * 4 + 2] * scale); - pp[13] = float2int8(p0[B_hstep * 4 + 6] * scale); - pp[14] = float2int8(p0[B_hstep * 4 + 3] * scale); - pp[15] = float2int8(p0[B_hstep * 4 + 7] * scale); +#if __AVX__ + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 4); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#else // __AVX__ + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + 4); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 4 + 4); + + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + __m128 _t2 = _mm_unpacklo_ps(_p2, _p3); + __m128 _t3 = _mm_unpackhi_ps(_p2, _p3); + _t0 = _mm_mul_ps(_t0, _scale); + _t1 = _mm_mul_ps(_t1, _scale); + _t2 = _mm_mul_ps(_t2, _scale); + _t3 = _mm_mul_ps(_t3, _scale); + + __m128i _pp = float2int8_sse(_t0, _t1, _t2, _t3); +#endif // __AVX__ + + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += 8; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp[4] = float2int8(p0[B_hstep * 4] * scale); - pp[5] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[6] = float2int8(p0[B_hstep * 4 + 2] * scale); - pp[7] = float2int8(p0[B_hstep * 4 + 3] * scale); + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep * 4); + +#if __AVX__ + __m256 _p = combine4x2_ps(_p0, _p1); + _p = _mm256_mul_ps(_p, _scale); + int64_t v = float2int8_avx(_p); +#else // __AVX__ + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + + int64_t v = float2int8_sse(_p0, _p1); +#endif // __AVX__ + + *(int64_t*)pp = v; pp += 8; p0 += 4; } @@ -7420,75 +7264,105 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2] * scale) + 127; - pp[3] = float2int8(p0[3] * scale) + 127; - pp[4] = float2int8(p0[B_hstep] * scale) + 127; - pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; - pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; - pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; - pp[8] = float2int8(p0[B_hstep * 2] * scale) + 127; - pp[9] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; - pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; - pp[11] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; - pp[12] = float2int8(p0[B_hstep * 3] * scale) + 127; - pp[13] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; - pp[14] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; - pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; - pp[16] = float2int8(p0[B_hstep * 4] * scale) + 127; - pp[17] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; - pp[18] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; - pp[19] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; - pp[20] = float2int8(p0[B_hstep * 5] * scale) + 127; - pp[21] = float2int8(p0[B_hstep * 5 + 1] * scale) + 127; - pp[22] = float2int8(p0[B_hstep * 5 + 2] * scale) + 127; - pp[23] = float2int8(p0[B_hstep * 5 + 3] * scale) + 127; - pp[24] = float2int8(p0[B_hstep * 6] * scale) + 127; - 
pp[25] = float2int8(p0[B_hstep * 6 + 1] * scale) + 127; - pp[26] = float2int8(p0[B_hstep * 6 + 2] * scale) + 127; - pp[27] = float2int8(p0[B_hstep * 6 + 3] * scale) + 127; - pp[28] = float2int8(p0[B_hstep * 7] * scale) + 127; - pp[29] = float2int8(p0[B_hstep * 7 + 1] * scale) + 127; - pp[30] = float2int8(p0[B_hstep * 7 + 2] * scale) + 127; - pp[31] = float2int8(p0[B_hstep * 7 + 3] * scale) + 127; + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + B_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + B_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + B_hstep * 7); + + __m256 _t0 = combine4x2_ps(_p0, _p1); + __m256 _t1 = combine4x2_ps(_p2, _p3); + __m256 _t2 = combine4x2_ps(_p4, _p5); + __m256 _t3 = combine4x2_ps(_p6, _p7); + + _t0 = _mm256_mul_ps(_t0, _scale); + _t1 = _mm256_mul_ps(_t1, _scale); + _t2 = _mm256_mul_ps(_t2, _scale); + _t3 = _mm256_mul_ps(_t3, _scale); + + __m128i _pp0 = float2int8_avx(_t0, _t1); + __m128i _pp1 = float2int8_avx(_t2, _t3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _pp = _mm256_add_epi8(_pp, _v127); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += 4; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[B_hstep] * scale); - pp[3] = float2int8(p0[B_hstep + 1] * scale); - pp[4] = float2int8(p0[B_hstep * 2] * scale); - pp[5] = float2int8(p0[B_hstep * 2 + 1] * scale); - pp[6] = float2int8(p0[B_hstep * 3] * scale); - pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale); - pp[8] = float2int8(p0[B_hstep * 4] * scale); - pp[9] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[10] = float2int8(p0[B_hstep * 5] * scale); - pp[11] = float2int8(p0[B_hstep * 5 + 1] * scale); - pp[12] = float2int8(p0[B_hstep * 6] * scale); - pp[13] = float2int8(p0[B_hstep * 6 + 1] * scale); - pp[14] = float2int8(p0[B_hstep * 7] * scale); - pp[15] = float2int8(p0[B_hstep * 7 + 1] * scale); +#if __AVX__ +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B_hstep)); + __m256 _p0 = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); + __m256 _p1 = _mm256_i32gather_ps(p0 + 1, _vindex, sizeof(float)); +#else + __m256 _p0 = _mm256_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1], p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]); + __m256 _p1 = _mm256_setr_ps(p0[B_hstep * 4], p0[B_hstep * 4 + 1], p0[B_hstep * 5], p0[B_hstep * 5 + 1], p0[B_hstep * 6], p0[B_hstep * 6 + 1], p0[B_hstep * 7], p0[B_hstep * 7 + 1]); +#endif + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + +#if __AVX2__ + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); +#endif +#else // __AVX__ + __m128 _p0 = _mm_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1]); + __m128 _p1 = _mm_setr_ps(p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]); + __m128 _p2 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 4 + 1], p0[B_hstep * 5], p0[B_hstep * 5 + 1]); + __m128 _p3 = _mm_setr_ps(p0[B_hstep * 6], p0[B_hstep * 6 + 1], p0[B_hstep * 7], p0[B_hstep * 7 + 1]); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = 
_mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); +#endif // __AVX__ + + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += 2; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[B_hstep] * scale); - pp[2] = float2int8(p0[B_hstep * 2] * scale); - pp[3] = float2int8(p0[B_hstep * 3] * scale); - pp[4] = float2int8(p0[B_hstep * 4] * scale); - pp[5] = float2int8(p0[B_hstep * 5] * scale); - pp[6] = float2int8(p0[B_hstep * 6] * scale); - pp[7] = float2int8(p0[B_hstep * 7] * scale); +#if __AVX__ +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B_hstep)); + __m256 _p = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m256 _p = _mm256_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3], p0[B_hstep * 4], p0[B_hstep * 5], p0[B_hstep * 6], p0[B_hstep * 7]); +#endif + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); +#else // __AVX__ + __m128 _p0 = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]); + __m128 _p1 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 5], p0[B_hstep * 6], p0[B_hstep * 7]); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + + int64_t v = float2int8_sse(_p0, _p1); +#endif // __AVX__ + + *(int64_t*)pp = v; pp += 8; p0++; } @@ -7499,6 +7373,11 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512F__ if (elempack == 16) { + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; jj + 15 < max_jj; jj += 16) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; @@ -7511,73 +7390,33 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[16] * scale) + 127; - pp[2] = float2int8(p0[32] * scale) + 127; - pp[3] = float2int8(p0[48] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[17] * scale) + 127; - pp[6] = float2int8(p0[33] * scale) + 127; - pp[7] = float2int8(p0[49] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[18] * scale) + 127; - pp[10] = float2int8(p0[34] * scale) + 127; - pp[11] = float2int8(p0[50] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[19] * scale) + 127; - pp[14] = float2int8(p0[35] * scale) + 127; - pp[15] = float2int8(p0[51] * scale) + 127; - - pp1[0] = float2int8(p0[4] * scale) + 127; - pp1[1] = float2int8(p0[20] * scale) + 127; - pp1[2] = float2int8(p0[36] * scale) + 127; - pp1[3] = float2int8(p0[52] * scale) + 127; - pp1[4] = float2int8(p0[5] * scale) + 127; - pp1[5] = float2int8(p0[21] * scale) + 127; - pp1[6] = float2int8(p0[37] * scale) + 127; - pp1[7] = float2int8(p0[53] * scale) + 127; - pp1[8] = float2int8(p0[6] * scale) + 127; - pp1[9] = float2int8(p0[22] * scale) + 127; - pp1[10] = float2int8(p0[38] * scale) + 127; - pp1[11] = float2int8(p0[54] * scale) + 127; - pp1[12] = float2int8(p0[7] * scale) + 127; - pp1[13] = float2int8(p0[23] * scale) + 127; - pp1[14] = float2int8(p0[39] * scale) + 127; - pp1[15] = float2int8(p0[55] * scale) + 127; - - pp2[0] = float2int8(p0[8] * scale) + 127; - pp2[1] = float2int8(p0[24] * scale) + 127; - pp2[2] = float2int8(p0[40] * scale) + 127; - pp2[3] = float2int8(p0[56] * scale) + 127; - 
pp2[4] = float2int8(p0[9] * scale) + 127; - pp2[5] = float2int8(p0[25] * scale) + 127; - pp2[6] = float2int8(p0[41] * scale) + 127; - pp2[7] = float2int8(p0[57] * scale) + 127; - pp2[8] = float2int8(p0[10] * scale) + 127; - pp2[9] = float2int8(p0[26] * scale) + 127; - pp2[10] = float2int8(p0[42] * scale) + 127; - pp2[11] = float2int8(p0[58] * scale) + 127; - pp2[12] = float2int8(p0[11] * scale) + 127; - pp2[13] = float2int8(p0[27] * scale) + 127; - pp2[14] = float2int8(p0[43] * scale) + 127; - pp2[15] = float2int8(p0[59] * scale) + 127; - - pp3[0] = float2int8(p0[12] * scale) + 127; - pp3[1] = float2int8(p0[28] * scale) + 127; - pp3[2] = float2int8(p0[44] * scale) + 127; - pp3[3] = float2int8(p0[60] * scale) + 127; - pp3[4] = float2int8(p0[13] * scale) + 127; - pp3[5] = float2int8(p0[29] * scale) + 127; - pp3[6] = float2int8(p0[45] * scale) + 127; - pp3[7] = float2int8(p0[61] * scale) + 127; - pp3[8] = float2int8(p0[14] * scale) + 127; - pp3[9] = float2int8(p0[30] * scale) + 127; - pp3[10] = float2int8(p0[46] * scale) + 127; - pp3[11] = float2int8(p0[62] * scale) + 127; - pp3[12] = float2int8(p0[15] * scale) + 127; - pp3[13] = float2int8(p0[31] * scale) + 127; - pp3[14] = float2int8(p0[47] * scale) + 127; - pp3[15] = float2int8(p0[63] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); + _pp2 = _mm_add_epi8(_pp2, _v127); + _pp3 = _mm_add_epi8(_pp3, _v127); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + _mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)pp1, _pp1); + _mm_storeu_si128((__m128i*)pp2, _pp2); + _mm_storeu_si128((__m128i*)pp3, _pp3); + pp += 16; pp1 += 16; pp2 += 16; @@ -7587,41 +7426,23 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[16] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[17] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[18] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[19] * scale); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_t0)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_t0)); + _mm_storel_pd((double*)pp2, _mm_castsi128_pd(_t1)); + _mm_storeh_pd((double*)pp3, _mm_castsi128_pd(_t1)); - pp1[0] = float2int8(p0[4] * scale); - pp1[1] = float2int8(p0[20] * scale); - pp1[2] = float2int8(p0[5] * scale); - pp1[3] = float2int8(p0[21] * scale); - pp1[4] = float2int8(p0[6] * scale); - pp1[5] = float2int8(p0[22] * scale); - pp1[6] = float2int8(p0[7] * scale); - pp1[7] = float2int8(p0[23] * scale); - - pp2[0] = float2int8(p0[8] * scale); - pp2[1] = float2int8(p0[24] * scale); - pp2[2] = 
float2int8(p0[9] * scale); - pp2[3] = float2int8(p0[25] * scale); - pp2[4] = float2int8(p0[10] * scale); - pp2[5] = float2int8(p0[26] * scale); - pp2[6] = float2int8(p0[11] * scale); - pp2[7] = float2int8(p0[27] * scale); - - pp3[0] = float2int8(p0[12] * scale); - pp3[1] = float2int8(p0[28] * scale); - pp3[2] = float2int8(p0[13] * scale); - pp3[3] = float2int8(p0[29] * scale); - pp3[4] = float2int8(p0[14] * scale); - pp3[5] = float2int8(p0[30] * scale); - pp3[6] = float2int8(p0[15] * scale); - pp3[7] = float2int8(p0[31] * scale); pp += 8; pp1 += 8; pp2 += 8; @@ -7630,25 +7451,17 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); + __m512 _p = _mm512_loadu_ps(p0); - pp1[0] = float2int8(p0[4] * scale); - pp1[1] = float2int8(p0[5] * scale); - pp1[2] = float2int8(p0[6] * scale); - pp1[3] = float2int8(p0[7] * scale); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _v = float2int8_avx512(_p); - pp2[0] = float2int8(p0[8] * scale); - pp2[1] = float2int8(p0[9] * scale); - pp2[2] = float2int8(p0[10] * scale); - pp2[3] = float2int8(p0[11] * scale); + *(int*)pp = _mm_extract_epi32(_v, 0); + *(int*)pp1 = _mm_extract_epi32(_v, 1); + *(int*)pp2 = _mm_extract_epi32(_v, 2); + *(int*)pp3 = _mm_extract_epi32(_v, 3); - pp3[0] = float2int8(p0[12] * scale); - pp3[1] = float2int8(p0[13] * scale); - pp3[2] = float2int8(p0[14] * scale); - pp3[3] = float2int8(p0[15] * scale); pp += 4; pp1 += 4; pp2 += 4; @@ -7662,6 +7475,11 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #endif // __AVX512F__ if (elempack == 8) { + __m256 _scale = _mm256_set1_ps(scale); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; jj + 7 < max_jj; jj += 8) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; @@ -7672,39 +7490,30 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[8] * scale) + 127; - pp[2] = float2int8(p0[16] * scale) + 127; - pp[3] = float2int8(p0[24] * scale) + 127; - pp[4] = float2int8(p0[1] * scale) + 127; - pp[5] = float2int8(p0[9] * scale) + 127; - pp[6] = float2int8(p0[17] * scale) + 127; - pp[7] = float2int8(p0[25] * scale) + 127; - pp[8] = float2int8(p0[2] * scale) + 127; - pp[9] = float2int8(p0[10] * scale) + 127; - pp[10] = float2int8(p0[18] * scale) + 127; - pp[11] = float2int8(p0[26] * scale) + 127; - pp[12] = float2int8(p0[3] * scale) + 127; - pp[13] = float2int8(p0[11] * scale) + 127; - pp[14] = float2int8(p0[19] * scale) + 127; - pp[15] = float2int8(p0[27] * scale) + 127; + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + 
_mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)pp1, _pp1); - pp1[0] = float2int8(p0[4] * scale) + 127; - pp1[1] = float2int8(p0[12] * scale) + 127; - pp1[2] = float2int8(p0[20] * scale) + 127; - pp1[3] = float2int8(p0[28] * scale) + 127; - pp1[4] = float2int8(p0[5] * scale) + 127; - pp1[5] = float2int8(p0[13] * scale) + 127; - pp1[6] = float2int8(p0[21] * scale) + 127; - pp1[7] = float2int8(p0[29] * scale) + 127; - pp1[8] = float2int8(p0[6] * scale) + 127; - pp1[9] = float2int8(p0[14] * scale) + 127; - pp1[10] = float2int8(p0[22] * scale) + 127; - pp1[11] = float2int8(p0[30] * scale) + 127; - pp1[12] = float2int8(p0[7] * scale) + 127; - pp1[13] = float2int8(p0[15] * scale) + 127; - pp1[14] = float2int8(p0[23] * scale) + 127; - pp1[15] = float2int8(p0[31] * scale) + 127; pp += 16; pp1 += 16; p0 += 32; @@ -7712,39 +7521,33 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[8] * scale); - pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[9] * scale); - pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[10] * scale); - pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[11] * scale); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); - pp1[0] = float2int8(p0[4] * scale); - pp1[1] = float2int8(p0[12] * scale); - pp1[2] = float2int8(p0[5] * scale); - pp1[3] = float2int8(p0[13] * scale); - pp1[4] = float2int8(p0[6] * scale); - pp1[5] = float2int8(p0[14] * scale); - pp1[6] = float2int8(p0[7] * scale); - pp1[7] = float2int8(p0[15] * scale); + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15); + _pp = _mm_shuffle_epi8(_pp, _si); + + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); pp += 8; pp1 += 8; p0 += 16; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); + __m256 _p = _mm256_loadu_ps(p0); + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); - pp1[0] = float2int8(p0[4] * scale); - pp1[1] = float2int8(p0[5] * scale); - pp1[2] = float2int8(p0[6] * scale); - pp1[3] = float2int8(p0[7] * scale); + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); pp += 4; pp1 += 4;
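
Note on the vectorized packing above, for reviewers: the rewritten pack_A / pack_B paths lean on two ideas. First, a pair of scaled float rows is quantized and saturated to int8 in one shot (the float2int8_avx / float2int8_avx512 / combine4x2 helpers the patch assumes from the x86 layer code), then byte-interleaved with unpack/shuffle so that each group of 2 or 4 consecutive k values of one output row lands contiguously, which is the layout the int8 GEMM microkernel consumes. Second, on VNNI targets pack_B adds +127 to every element so vpdpbusd can treat the B tile as unsigned, and pack_A stores a per-row w_shift = 127 * sum(a), accumulated with the dpbusd intrinsic against a vector of 127s, which the kernel subtracts afterwards. The sketch below restates both ideas with plain AVX2; quantize_two_rows and dot_with_bias_compensation are illustrative stand-ins written for this note, not ncnn's actual helpers or kernel.

// Minimal sketch, assuming AVX2 only; illustrative, not ncnn's float2int8_avx or GEMM kernel.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// Quantize two rows of 8 scaled floats to 16 saturated int8 values
// (row 0 in bytes 0..7, row 1 in bytes 8..15), roughly what the
// float2int8_avx(_p0, _p1) calls above produce before interleaving.
static inline __m128i quantize_two_rows(__m256 p0, __m256 p1)
{
    __m256i i0 = _mm256_cvtps_epi32(p0);              // round to nearest
    __m256i i1 = _mm256_cvtps_epi32(p1);
    __m256i w16 = _mm256_packs_epi32(i0, i1);         // saturate to int16
    __m256i w8 = _mm256_packs_epi16(w16, w16);        // saturate to int8
    w8 = _mm256_max_epi8(w8, _mm256_set1_epi8(-127)); // clamp to [-127, 127]
    // packs works per 128-bit lane; gather dwords 0,4,1,5 to restore row order
    w8 = _mm256_permutevar8x32_epi32(w8, _mm256_setr_epi32(0, 4, 1, 5, 0, 0, 0, 0));
    return _mm256_castsi256_si128(w8);
}

// vpdpbusd multiplies unsigned bytes by signed bytes.  Biasing the B tile by +127
// (as pack_B does) keeps it non-negative, and the extra 127 * sum(a) term it adds
// is exactly the per-row w_shift that pack_A stores and the kernel subtracts.
static int32_t dot_with_bias_compensation(const int8_t* a, const int8_t* b, int k)
{
    int32_t acc = 0;     // what vpdpbusd would accumulate: sum((b + 127) * a)
    int32_t w_shift = 0; // what pack_A stores: 127 * sum(a)
    for (int i = 0; i < k; i++)
    {
        acc += (int32_t)(uint8_t)(b[i] + 127) * a[i];
        w_shift += 127 * a[i];
    }
    return acc - w_shift; // equals the plain signed dot(a, b)
}

int main()
{
    float row0[8], row1[8];
    for (int i = 0; i < 8; i++)
    {
        row0[i] = 0.5f * i - 2.f;
        row1[i] = 1.f - 0.25f * i;
    }
    const float scale = 42.3f;
    __m256 p0 = _mm256_mul_ps(_mm256_loadu_ps(row0), _mm256_set1_ps(scale));
    __m256 p1 = _mm256_mul_ps(_mm256_loadu_ps(row1), _mm256_set1_ps(scale));

    int8_t q[16];
    _mm_storeu_si128((__m128i*)q, quantize_two_rows(p0, p1));

    int32_t ref = 0;
    for (int i = 0; i < 8; i++)
        ref += q[i] * q[8 + i];
    printf("dot=%d compensated=%d\n", ref, dot_with_bias_compensation(q, q + 8, 8));
    return 0;
}

Compile with -mavx2; the two printed values should match, confirming that subtracting w_shift removes the 127 * sum(a) term introduced by biasing the B tile.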