From 8ba8efea82bd6ea50c3b553a786a9861b7db5864 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 9 Dec 2024 11:43:40 +0000 Subject: [PATCH] opt packa packb avx --- src/layer/x86/gemm_int8.h | 2035 +++++++++++-------------------------- 1 file changed, 574 insertions(+), 1461 deletions(-) diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 3d9ef1022d4..165ce9eb38b 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -4592,446 +4592,104 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - const float scale0 = scales[i + ii]; - const float scale1 = scales[i + ii + 1]; - const float scale2 = scales[i + ii + 2]; - const float scale3 = scales[i + ii + 3]; - const float scale4 = scales[i + ii + 4]; - const float scale5 = scales[i + ii + 5]; - const float scale6 = scales[i + ii + 6]; - const float scale7 = scales[i + ii + 7]; + __m256 _scales = _mm256_loadu_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ #if __AVX512F__ if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m512i _w_shift_avx512 = _mm512_setzero_si512(); + __m512i _v127_avx512 = _mm512_set1_epi8(127); for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2 + 0] * scale0); - pp[3] = float2int8(p0[2 + 1] * scale0); - pp[4] = float2int8(p0[16] * scale1); - pp[5] = float2int8(p0[17] * scale1); - pp[6] = float2int8(p0[2 + 16] * scale1); - pp[7] = float2int8(p0[2 + 17] * scale1); - pp[8] = float2int8(p0[32] * scale2); - pp[9] = float2int8(p0[33] * scale2); - pp[10] = float2int8(p0[2 + 32] * scale2); - pp[11] = float2int8(p0[2 + 33] * scale2); - pp[12] = float2int8(p0[48] * scale3); - pp[13] = float2int8(p0[49] * scale3); - pp[14] = float2int8(p0[2 + 48] * scale3); - pp[15] = float2int8(p0[2 + 49] * scale3); - pp[16] = float2int8(p0[64] * scale4); - pp[17] = float2int8(p0[65] * scale4); - pp[18] = float2int8(p0[2 + 64] * scale4); - pp[19] = float2int8(p0[2 + 65] * scale4); - pp[20] = float2int8(p0[80] * scale5); - pp[21] = float2int8(p0[81] * scale5); - pp[22] = float2int8(p0[2 + 80] * scale5); - pp[23] = float2int8(p0[2 + 81] * scale5); - pp[24] = float2int8(p0[96] * scale6); - pp[25] = float2int8(p0[97] * scale6); - pp[26] = float2int8(p0[2 + 96] * scale6); - pp[27] = float2int8(p0[2 + 97] * scale6); - pp[28] = float2int8(p0[112] * scale7); - pp[29] = float2int8(p0[113] * scale7); - pp[30] = float2int8(p0[2 + 112] * scale7); - pp[31] = float2int8(p0[2 + 113] * scale7); - - pp[32 + 0] = float2int8(p0[4 + 0] * scale0); - pp[32 + 1] = float2int8(p0[4 + 1] * scale0); - pp[32 + 2] = float2int8(p0[6 + 0] * scale0); - pp[32 + 3] = float2int8(p0[6 + 1] * scale0); - pp[32 + 4] = float2int8(p0[4 + 16] * scale1); - pp[32 + 5] = float2int8(p0[4 + 17] * scale1); - pp[32 + 6] = float2int8(p0[6 + 16] * scale1); - pp[32 + 7] = float2int8(p0[6 + 17] * scale1); - pp[32 + 8] = float2int8(p0[4 + 32] * scale2); - pp[32 + 9] = float2int8(p0[4 + 33] * scale2); - pp[32 + 10] = float2int8(p0[6 + 32] * scale2); - pp[32 + 11] = float2int8(p0[6 + 33] * scale2); - pp[32 + 12] = float2int8(p0[4 + 48] * scale3); - pp[32 + 13] = float2int8(p0[4 + 49] * scale3); - pp[32 + 14] = 
float2int8(p0[6 + 48] * scale3); - pp[32 + 15] = float2int8(p0[6 + 49] * scale3); - pp[32 + 16] = float2int8(p0[4 + 64] * scale4); - pp[32 + 17] = float2int8(p0[4 + 65] * scale4); - pp[32 + 18] = float2int8(p0[6 + 64] * scale4); - pp[32 + 19] = float2int8(p0[6 + 65] * scale4); - pp[32 + 20] = float2int8(p0[4 + 80] * scale5); - pp[32 + 21] = float2int8(p0[4 + 81] * scale5); - pp[32 + 22] = float2int8(p0[6 + 80] * scale5); - pp[32 + 23] = float2int8(p0[6 + 81] * scale5); - pp[32 + 24] = float2int8(p0[4 + 96] * scale6); - pp[32 + 25] = float2int8(p0[4 + 97] * scale6); - pp[32 + 26] = float2int8(p0[6 + 96] * scale6); - pp[32 + 27] = float2int8(p0[6 + 97] * scale6); - pp[32 + 28] = float2int8(p0[4 + 112] * scale7); - pp[32 + 29] = float2int8(p0[4 + 113] * scale7); - pp[32 + 30] = float2int8(p0[6 + 112] * scale7); - pp[32 + 31] = float2int8(p0[6 + 113] * scale7); - - pp[64 + 0] = float2int8(p0[8 + 0] * scale0); - pp[64 + 1] = float2int8(p0[8 + 1] * scale0); - pp[64 + 2] = float2int8(p0[10 + 0] * scale0); - pp[64 + 3] = float2int8(p0[10 + 1] * scale0); - pp[64 + 4] = float2int8(p0[8 + 16] * scale1); - pp[64 + 5] = float2int8(p0[8 + 17] * scale1); - pp[64 + 6] = float2int8(p0[10 + 16] * scale1); - pp[64 + 7] = float2int8(p0[10 + 17] * scale1); - pp[64 + 8] = float2int8(p0[8 + 32] * scale2); - pp[64 + 9] = float2int8(p0[8 + 33] * scale2); - pp[64 + 10] = float2int8(p0[10 + 32] * scale2); - pp[64 + 11] = float2int8(p0[10 + 33] * scale2); - pp[64 + 12] = float2int8(p0[8 + 48] * scale3); - pp[64 + 13] = float2int8(p0[8 + 49] * scale3); - pp[64 + 14] = float2int8(p0[10 + 48] * scale3); - pp[64 + 15] = float2int8(p0[10 + 49] * scale3); - pp[64 + 16] = float2int8(p0[8 + 64] * scale4); - pp[64 + 17] = float2int8(p0[8 + 65] * scale4); - pp[64 + 18] = float2int8(p0[10 + 64] * scale4); - pp[64 + 19] = float2int8(p0[10 + 65] * scale4); - pp[64 + 20] = float2int8(p0[8 + 80] * scale5); - pp[64 + 21] = float2int8(p0[8 + 81] * scale5); - pp[64 + 22] = float2int8(p0[10 + 80] * scale5); - pp[64 + 23] = float2int8(p0[10 + 81] * scale5); - pp[64 + 24] = float2int8(p0[8 + 96] * scale6); - pp[64 + 25] = float2int8(p0[8 + 97] * scale6); - pp[64 + 26] = float2int8(p0[10 + 96] * scale6); - pp[64 + 27] = float2int8(p0[10 + 97] * scale6); - pp[64 + 28] = float2int8(p0[8 + 112] * scale7); - pp[64 + 29] = float2int8(p0[8 + 113] * scale7); - pp[64 + 30] = float2int8(p0[10 + 112] * scale7); - pp[64 + 31] = float2int8(p0[10 + 113] * scale7); - - pp[96 + 0] = float2int8(p0[12 + 0] * scale0); - pp[96 + 1] = float2int8(p0[12 + 1] * scale0); - pp[96 + 2] = float2int8(p0[14 + 0] * scale0); - pp[96 + 3] = float2int8(p0[14 + 1] * scale0); - pp[96 + 4] = float2int8(p0[12 + 16] * scale1); - pp[96 + 5] = float2int8(p0[12 + 17] * scale1); - pp[96 + 6] = float2int8(p0[14 + 16] * scale1); - pp[96 + 7] = float2int8(p0[14 + 17] * scale1); - pp[96 + 8] = float2int8(p0[12 + 32] * scale2); - pp[96 + 9] = float2int8(p0[12 + 33] * scale2); - pp[96 + 10] = float2int8(p0[14 + 32] * scale2); - pp[96 + 11] = float2int8(p0[14 + 33] * scale2); - pp[96 + 12] = float2int8(p0[12 + 48] * scale3); - pp[96 + 13] = float2int8(p0[12 + 49] * scale3); - pp[96 + 14] = float2int8(p0[14 + 48] * scale3); - pp[96 + 15] = float2int8(p0[14 + 49] * scale3); - pp[96 + 16] = float2int8(p0[12 + 64] * scale4); - pp[96 + 17] = float2int8(p0[12 + 65] * scale4); - pp[96 + 18] = float2int8(p0[14 + 64] * scale4); - pp[96 + 19] = float2int8(p0[14 + 65] * scale4); - pp[96 + 20] = float2int8(p0[12 + 80] * scale5); - pp[96 + 21] = float2int8(p0[12 + 81] * scale5); - pp[96 + 22] = 
float2int8(p0[14 + 80] * scale5); - pp[96 + 23] = float2int8(p0[14 + 81] * scale5); - pp[96 + 24] = float2int8(p0[12 + 96] * scale6); - pp[96 + 25] = float2int8(p0[12 + 97] * scale6); - pp[96 + 26] = float2int8(p0[14 + 96] * scale6); - pp[96 + 27] = float2int8(p0[14 + 97] * scale6); - pp[96 + 28] = float2int8(p0[12 + 112] * scale7); - pp[96 + 29] = float2int8(p0[12 + 113] * scale7); - pp[96 + 30] = float2int8(p0[14 + 112] * scale7); - pp[96 + 31] = float2int8(p0[14 + 113] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose4x8_epi32(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _w_shift_avx512 = _mm512_dpbusd_epi32(_w_shift_avx512, _v127_avx512, _t0); + _w_shift_avx512 = _mm512_dpbusd_epi32(_w_shift_avx512, _v127_avx512, _t1); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); - w_shift0 += pp[32 + 0]; - w_shift0 += pp[32 + 1]; - w_shift0 += pp[32 + 2]; - w_shift0 += pp[32 + 3]; - w_shift1 += pp[32 + 4]; - w_shift1 += pp[32 + 5]; - w_shift1 += pp[32 + 6]; - w_shift1 += pp[32 + 7]; - w_shift2 += pp[32 + 8]; - w_shift2 += pp[32 + 9]; - w_shift2 += pp[32 + 10]; - w_shift2 += pp[32 + 11]; - w_shift3 += pp[32 + 12]; - w_shift3 += pp[32 + 13]; - w_shift3 += pp[32 + 14]; - w_shift3 += pp[32 + 15]; - w_shift4 += pp[32 + 16]; - w_shift4 += pp[32 + 17]; - w_shift4 += pp[32 + 18]; - w_shift4 += pp[32 + 19]; - w_shift5 += pp[32 + 20]; - w_shift5 += pp[32 + 21]; - w_shift5 += pp[32 + 22]; - w_shift5 += pp[32 + 23]; - w_shift6 += pp[32 + 24]; - w_shift6 += pp[32 + 25]; - w_shift6 += pp[32 + 26]; - w_shift6 += pp[32 + 27]; - w_shift7 += pp[32 + 
28]; - w_shift7 += pp[32 + 29]; - w_shift7 += pp[32 + 30]; - w_shift7 += pp[32 + 31]; - - w_shift0 += pp[64 + 0]; - w_shift0 += pp[64 + 1]; - w_shift0 += pp[64 + 2]; - w_shift0 += pp[64 + 3]; - w_shift1 += pp[64 + 4]; - w_shift1 += pp[64 + 5]; - w_shift1 += pp[64 + 6]; - w_shift1 += pp[64 + 7]; - w_shift2 += pp[64 + 8]; - w_shift2 += pp[64 + 9]; - w_shift2 += pp[64 + 10]; - w_shift2 += pp[64 + 11]; - w_shift3 += pp[64 + 12]; - w_shift3 += pp[64 + 13]; - w_shift3 += pp[64 + 14]; - w_shift3 += pp[64 + 15]; - w_shift4 += pp[64 + 16]; - w_shift4 += pp[64 + 17]; - w_shift4 += pp[64 + 18]; - w_shift4 += pp[64 + 19]; - w_shift5 += pp[64 + 20]; - w_shift5 += pp[64 + 21]; - w_shift5 += pp[64 + 22]; - w_shift5 += pp[64 + 23]; - w_shift6 += pp[64 + 24]; - w_shift6 += pp[64 + 25]; - w_shift6 += pp[64 + 26]; - w_shift6 += pp[64 + 27]; - w_shift7 += pp[64 + 28]; - w_shift7 += pp[64 + 29]; - w_shift7 += pp[64 + 30]; - w_shift7 += pp[64 + 31]; - - w_shift0 += pp[96 + 0]; - w_shift0 += pp[96 + 1]; - w_shift0 += pp[96 + 2]; - w_shift0 += pp[96 + 3]; - w_shift1 += pp[96 + 4]; - w_shift1 += pp[96 + 5]; - w_shift1 += pp[96 + 6]; - w_shift1 += pp[96 + 7]; - w_shift2 += pp[96 + 8]; - w_shift2 += pp[96 + 9]; - w_shift2 += pp[96 + 10]; - w_shift2 += pp[96 + 11]; - w_shift3 += pp[96 + 12]; - w_shift3 += pp[96 + 13]; - w_shift3 += pp[96 + 14]; - w_shift3 += pp[96 + 15]; - w_shift4 += pp[96 + 16]; - w_shift4 += pp[96 + 17]; - w_shift4 += pp[96 + 18]; - w_shift4 += pp[96 + 19]; - w_shift5 += pp[96 + 20]; - w_shift5 += pp[96 + 21]; - w_shift5 += pp[96 + 22]; - w_shift5 += pp[96 + 23]; - w_shift6 += pp[96 + 24]; - w_shift6 += pp[96 + 25]; - w_shift6 += pp[96 + 26]; - w_shift6 += pp[96 + 27]; - w_shift7 += pp[96 + 28]; - w_shift7 += pp[96 + 29]; - w_shift7 += pp[96 + 30]; - w_shift7 += pp[96 + 31]; pp += 128; p0 += A_hstep * 16; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + __m256i _w_shift = _mm256_add_epi32(_mm512_extracti32x8_epi32(_w_shift_avx512, 0), _mm512_extracti32x8_epi32(_w_shift_avx512, 1)); + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[16] * scale1); - pp[3] = float2int8(p0[17] * scale1); - pp[4] = float2int8(p0[32] * scale2); - pp[5] = float2int8(p0[33] * scale2); - pp[6] = float2int8(p0[48] * scale3); - pp[7] = float2int8(p0[49] * scale3); - pp[8] = float2int8(p0[64] * scale4); - pp[9] = float2int8(p0[65] * scale4); - pp[10] = float2int8(p0[80] * scale5); - pp[11] = float2int8(p0[81] * scale5); - pp[12] = float2int8(p0[96] * scale6); - pp[13] = float2int8(p0[97] * scale6); - pp[14] = float2int8(p0[112] * scale7); - pp[15] = float2int8(p0[113] * scale7); - - pp[16 + 0] = float2int8(p0[2 + 0] * scale0); - pp[16 + 1] = float2int8(p0[2 + 1] * scale0); - pp[16 + 2] = float2int8(p0[2 + 16] * scale1); - pp[16 + 3] = float2int8(p0[2 + 17] * scale1); - pp[16 + 4] = float2int8(p0[2 + 32] * scale2); - pp[16 + 5] = float2int8(p0[2 + 33] * scale2); - pp[16 + 6] = float2int8(p0[2 + 48] * scale3); - pp[16 + 7] = float2int8(p0[2 + 49] * scale3); - pp[16 + 8] = float2int8(p0[2 + 64] * scale4); - pp[16 + 9] = float2int8(p0[2 + 65] * scale4); - pp[16 + 10] = float2int8(p0[2 + 80] * scale5); - pp[16 + 11] = 
float2int8(p0[2 + 81] * scale5); - pp[16 + 12] = float2int8(p0[2 + 96] * scale6); - pp[16 + 13] = float2int8(p0[2 + 97] * scale6); - pp[16 + 14] = float2int8(p0[2 + 112] * scale7); - pp[16 + 15] = float2int8(p0[2 + 113] * scale7); - - pp[32 + 0] = float2int8(p0[4 + 0] * scale0); - pp[32 + 1] = float2int8(p0[4 + 1] * scale0); - pp[32 + 2] = float2int8(p0[4 + 16] * scale1); - pp[32 + 3] = float2int8(p0[4 + 17] * scale1); - pp[32 + 4] = float2int8(p0[4 + 32] * scale2); - pp[32 + 5] = float2int8(p0[4 + 33] * scale2); - pp[32 + 6] = float2int8(p0[4 + 48] * scale3); - pp[32 + 7] = float2int8(p0[4 + 49] * scale3); - pp[32 + 8] = float2int8(p0[4 + 64] * scale4); - pp[32 + 9] = float2int8(p0[4 + 65] * scale4); - pp[32 + 10] = float2int8(p0[4 + 80] * scale5); - pp[32 + 11] = float2int8(p0[4 + 81] * scale5); - pp[32 + 12] = float2int8(p0[4 + 96] * scale6); - pp[32 + 13] = float2int8(p0[4 + 97] * scale6); - pp[32 + 14] = float2int8(p0[4 + 112] * scale7); - pp[32 + 15] = float2int8(p0[4 + 113] * scale7); - - pp[48 + 0] = float2int8(p0[6 + 0] * scale0); - pp[48 + 1] = float2int8(p0[6 + 1] * scale0); - pp[48 + 2] = float2int8(p0[6 + 16] * scale1); - pp[48 + 3] = float2int8(p0[6 + 17] * scale1); - pp[48 + 4] = float2int8(p0[6 + 32] * scale2); - pp[48 + 5] = float2int8(p0[6 + 33] * scale2); - pp[48 + 6] = float2int8(p0[6 + 48] * scale3); - pp[48 + 7] = float2int8(p0[6 + 49] * scale3); - pp[48 + 8] = float2int8(p0[6 + 64] * scale4); - pp[48 + 9] = float2int8(p0[6 + 65] * scale4); - pp[48 + 10] = float2int8(p0[6 + 80] * scale5); - pp[48 + 11] = float2int8(p0[6 + 81] * scale5); - pp[48 + 12] = float2int8(p0[6 + 96] * scale6); - pp[48 + 13] = float2int8(p0[6 + 97] * scale6); - pp[48 + 14] = float2int8(p0[6 + 112] * scale7); - pp[48 + 15] = float2int8(p0[6 + 113] * scale7); - - pp[64 + 0] = float2int8(p0[8 + 0] * scale0); - pp[64 + 1] = float2int8(p0[8 + 1] * scale0); - pp[64 + 2] = float2int8(p0[8 + 16] * scale1); - pp[64 + 3] = float2int8(p0[8 + 17] * scale1); - pp[64 + 4] = float2int8(p0[8 + 32] * scale2); - pp[64 + 5] = float2int8(p0[8 + 33] * scale2); - pp[64 + 6] = float2int8(p0[8 + 48] * scale3); - pp[64 + 7] = float2int8(p0[8 + 49] * scale3); - pp[64 + 8] = float2int8(p0[8 + 64] * scale4); - pp[64 + 9] = float2int8(p0[8 + 65] * scale4); - pp[64 + 10] = float2int8(p0[8 + 80] * scale5); - pp[64 + 11] = float2int8(p0[8 + 81] * scale5); - pp[64 + 12] = float2int8(p0[8 + 96] * scale6); - pp[64 + 13] = float2int8(p0[8 + 97] * scale6); - pp[64 + 14] = float2int8(p0[8 + 112] * scale7); - pp[64 + 15] = float2int8(p0[8 + 113] * scale7); - - pp[80 + 0] = float2int8(p0[10 + 0] * scale0); - pp[80 + 1] = float2int8(p0[10 + 1] * scale0); - pp[80 + 2] = float2int8(p0[10 + 16] * scale1); - pp[80 + 3] = float2int8(p0[10 + 17] * scale1); - pp[80 + 4] = float2int8(p0[10 + 32] * scale2); - pp[80 + 5] = float2int8(p0[10 + 33] * scale2); - pp[80 + 6] = float2int8(p0[10 + 48] * scale3); - pp[80 + 7] = float2int8(p0[10 + 49] * scale3); - pp[80 + 8] = float2int8(p0[10 + 64] * scale4); - pp[80 + 9] = float2int8(p0[10 + 65] * scale4); - pp[80 + 10] = float2int8(p0[10 + 80] * scale5); - pp[80 + 11] = float2int8(p0[10 + 81] * scale5); - pp[80 + 12] = float2int8(p0[10 + 96] * scale6); - pp[80 + 13] = float2int8(p0[10 + 97] * scale6); - pp[80 + 14] = float2int8(p0[10 + 112] * scale7); - pp[80 + 15] = float2int8(p0[10 + 113] * scale7); - - pp[96 + 0] = float2int8(p0[12 + 0] * scale0); - pp[96 + 1] = float2int8(p0[12 + 1] * scale0); - pp[96 + 2] = float2int8(p0[12 + 16] * scale1); - pp[96 + 3] = float2int8(p0[12 + 17] * scale1); - pp[96 
+ 4] = float2int8(p0[12 + 32] * scale2); - pp[96 + 5] = float2int8(p0[12 + 33] * scale2); - pp[96 + 6] = float2int8(p0[12 + 48] * scale3); - pp[96 + 7] = float2int8(p0[12 + 49] * scale3); - pp[96 + 8] = float2int8(p0[12 + 64] * scale4); - pp[96 + 9] = float2int8(p0[12 + 65] * scale4); - pp[96 + 10] = float2int8(p0[12 + 80] * scale5); - pp[96 + 11] = float2int8(p0[12 + 81] * scale5); - pp[96 + 12] = float2int8(p0[12 + 96] * scale6); - pp[96 + 13] = float2int8(p0[12 + 97] * scale6); - pp[96 + 14] = float2int8(p0[12 + 112] * scale7); - pp[96 + 15] = float2int8(p0[12 + 113] * scale7); - - pp[112 + 0] = float2int8(p0[14 + 0] * scale0); - pp[112 + 1] = float2int8(p0[14 + 1] * scale0); - pp[112 + 2] = float2int8(p0[14 + 16] * scale1); - pp[112 + 3] = float2int8(p0[14 + 17] * scale1); - pp[112 + 4] = float2int8(p0[14 + 32] * scale2); - pp[112 + 5] = float2int8(p0[14 + 33] * scale2); - pp[112 + 6] = float2int8(p0[14 + 48] * scale3); - pp[112 + 7] = float2int8(p0[14 + 49] * scale3); - pp[112 + 8] = float2int8(p0[14 + 64] * scale4); - pp[112 + 9] = float2int8(p0[14 + 65] * scale4); - pp[112 + 10] = float2int8(p0[14 + 80] * scale5); - pp[112 + 11] = float2int8(p0[14 + 81] * scale5); - pp[112 + 12] = float2int8(p0[14 + 96] * scale6); - pp[112 + 13] = float2int8(p0[14 + 97] * scale6); - pp[112 + 14] = float2int8(p0[14 + 112] * scale7); - pp[112 + 15] = float2int8(p0[14 + 113] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose8x8_epi16(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); pp += 128; p0 += A_hstep * 16; @@ -5043,449 +4701,193 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 7 < max_kk; kk += 8) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[8] * scale1); - pp[5] = float2int8(p0[9] * scale1); - pp[6] = float2int8(p0[10] * scale1); - pp[7] = 
float2int8(p0[11] * scale1); - pp[8] = float2int8(p0[16] * scale2); - pp[9] = float2int8(p0[17] * scale2); - pp[10] = float2int8(p0[18] * scale2); - pp[11] = float2int8(p0[19] * scale2); - pp[12] = float2int8(p0[24] * scale3); - pp[13] = float2int8(p0[25] * scale3); - pp[14] = float2int8(p0[26] * scale3); - pp[15] = float2int8(p0[27] * scale3); - pp[16] = float2int8(p0[32] * scale4); - pp[17] = float2int8(p0[33] * scale4); - pp[18] = float2int8(p0[34] * scale4); - pp[19] = float2int8(p0[35] * scale4); - pp[20] = float2int8(p0[40] * scale5); - pp[21] = float2int8(p0[41] * scale5); - pp[22] = float2int8(p0[42] * scale5); - pp[23] = float2int8(p0[43] * scale5); - pp[24] = float2int8(p0[48] * scale6); - pp[25] = float2int8(p0[49] * scale6); - pp[26] = float2int8(p0[50] * scale6); - pp[27] = float2int8(p0[51] * scale6); - pp[28] = float2int8(p0[56] * scale7); - pp[29] = float2int8(p0[57] * scale7); - pp[30] = float2int8(p0[58] * scale7); - pp[31] = float2int8(p0[59] * scale7); - - pp[32 + 0] = float2int8(p0[4] * scale0); - pp[32 + 1] = float2int8(p0[5] * scale0); - pp[32 + 2] = float2int8(p0[6] * scale0); - pp[32 + 3] = float2int8(p0[7] * scale0); - pp[32 + 4] = float2int8(p0[12] * scale1); - pp[32 + 5] = float2int8(p0[13] * scale1); - pp[32 + 6] = float2int8(p0[14] * scale1); - pp[32 + 7] = float2int8(p0[15] * scale1); - pp[32 + 8] = float2int8(p0[20] * scale2); - pp[32 + 9] = float2int8(p0[21] * scale2); - pp[32 + 10] = float2int8(p0[22] * scale2); - pp[32 + 11] = float2int8(p0[23] * scale2); - pp[32 + 12] = float2int8(p0[28] * scale3); - pp[32 + 13] = float2int8(p0[29] * scale3); - pp[32 + 14] = float2int8(p0[30] * scale3); - pp[32 + 15] = float2int8(p0[31] * scale3); - pp[32 + 16] = float2int8(p0[36] * scale4); - pp[32 + 17] = float2int8(p0[37] * scale4); - pp[32 + 18] = float2int8(p0[38] * scale4); - pp[32 + 19] = float2int8(p0[39] * scale4); - pp[32 + 20] = float2int8(p0[44] * scale5); - pp[32 + 21] = float2int8(p0[45] * scale5); - pp[32 + 22] = float2int8(p0[46] * scale5); - pp[32 + 23] = float2int8(p0[47] * scale5); - pp[32 + 24] = float2int8(p0[52] * scale6); - pp[32 + 25] = float2int8(p0[53] * scale6); - pp[32 + 26] = float2int8(p0[54] * scale6); - pp[32 + 27] = float2int8(p0[55] * scale6); - pp[32 + 28] = float2int8(p0[60] * scale7); - pp[32 + 29] = float2int8(p0[61] * scale7); - pp[32 + 30] = float2int8(p0[62] * scale7); - pp[32 + 31] = float2int8(p0[63] * scale7); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + __m256 _p4 = _mm256_loadu_ps(p0 + 32); + __m256 _p5 = _mm256_loadu_ps(p0 + 40); + __m256 _p6 = _mm256_loadu_ps(p0 + 48); + __m256 _p7 = _mm256_loadu_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _mm256_set1_ps(scales[i + ii])); + _p1 = _mm256_mul_ps(_p1, _mm256_set1_ps(scales[i + ii + 1])); + _p2 = _mm256_mul_ps(_p2, _mm256_set1_ps(scales[i + ii + 2])); + _p3 = _mm256_mul_ps(_p3, _mm256_set1_ps(scales[i + ii + 3])); + _p4 = _mm256_mul_ps(_p4, _mm256_set1_ps(scales[i + ii + 4])); + _p5 = _mm256_mul_ps(_p5, _mm256_set1_ps(scales[i + ii + 5])); + _p6 = _mm256_mul_ps(_p6, _mm256_set1_ps(scales[i + ii + 6])); + _p7 = _mm256_mul_ps(_p7, _mm256_set1_ps(scales[i + ii + 7])); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - 
w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); + + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t2 = _mm256_unpacklo_epi32(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi32(_t0, _t1); + _t0 = _mm256_unpacklo_epi64(_t2, _t3); + _t1 = _mm256_unpackhi_epi64(_t2, _t3); + + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _t0); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _t1); + + _mm256_storeu_si256((__m256i*)pp, _t0); + _mm256_storeu_si256((__m256i*)(pp + 32), _t1); - w_shift0 += pp[32 + 0]; - w_shift0 += pp[32 + 1]; - w_shift0 += pp[32 + 2]; - w_shift0 += pp[32 + 3]; - w_shift1 += pp[32 + 4]; - w_shift1 += pp[32 + 5]; - w_shift1 += pp[32 + 6]; - w_shift1 += pp[32 + 7]; - w_shift2 += pp[32 + 8]; - w_shift2 += pp[32 + 9]; - w_shift2 += pp[32 + 10]; - w_shift2 += pp[32 + 11]; - w_shift3 += pp[32 + 12]; - w_shift3 += pp[32 + 13]; - w_shift3 += pp[32 + 14]; - w_shift3 += pp[32 + 15]; - w_shift4 += pp[32 + 16]; - w_shift4 += pp[32 + 17]; - w_shift4 += pp[32 + 18]; - w_shift4 += pp[32 + 19]; - w_shift5 += pp[32 + 20]; - w_shift5 += pp[32 + 21]; - w_shift5 += pp[32 + 22]; - w_shift5 += pp[32 + 23]; - w_shift6 += pp[32 + 24]; - w_shift6 += pp[32 + 25]; - w_shift6 += pp[32 + 26]; - w_shift6 += pp[32 + 27]; - w_shift7 += pp[32 + 28]; - w_shift7 += pp[32 + 29]; - w_shift7 += pp[32 + 30]; - w_shift7 += pp[32 + 31]; pp += 64; p0 += A_hstep * 8; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #else // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 7 < max_kk; kk += 8) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[8] * scale1); - pp[3] = float2int8(p0[9] * scale1); - pp[4] = float2int8(p0[16] * scale2); - pp[5] = float2int8(p0[17] * scale2); - pp[6] = float2int8(p0[24] * scale3); - pp[7] = float2int8(p0[25] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[32] * scale4); - pp[9] = float2int8(p0[33] * scale4); - pp[10] = float2int8(p0[40] * scale5); - pp[11] = float2int8(p0[41] * scale5); - pp[12] = float2int8(p0[48] * scale6); - pp[13] = float2int8(p0[49] * scale6); - pp[14] = float2int8(p0[56] * scale7); - pp[15] = float2int8(p0[57] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[32] * scale4); - pp1[1] = float2int8(p0[33] * scale4); - pp1[2] = float2int8(p0[40] * scale5); - pp1[3] = float2int8(p0[41] * scale5); - pp1[4] = float2int8(p0[48] * scale6); - pp1[5] = float2int8(p0[49] * scale6); - pp1[6] = float2int8(p0[56] * scale7); - pp1[7] = float2int8(p0[57] * scale7); - pp += 8; - pp1 += 8; -#endif + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + __m256 _p4 = 
_mm256_loadu_ps(p0 + 32); + __m256 _p5 = _mm256_loadu_ps(p0 + 40); + __m256 _p6 = _mm256_loadu_ps(p0 + 48); + __m256 _p7 = _mm256_loadu_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _mm256_set1_ps(scales[i + ii])); + _p1 = _mm256_mul_ps(_p1, _mm256_set1_ps(scales[i + ii + 1])); + _p2 = _mm256_mul_ps(_p2, _mm256_set1_ps(scales[i + ii + 2])); + _p3 = _mm256_mul_ps(_p3, _mm256_set1_ps(scales[i + ii + 3])); + _p4 = _mm256_mul_ps(_p4, _mm256_set1_ps(scales[i + ii + 4])); + _p5 = _mm256_mul_ps(_p5, _mm256_set1_ps(scales[i + ii + 5])); + _p6 = _mm256_mul_ps(_p6, _mm256_set1_ps(scales[i + ii + 6])); + _p7 = _mm256_mul_ps(_p7, _mm256_set1_ps(scales[i + ii + 7])); - pp[0] = float2int8(p0[2] * scale0); - pp[1] = float2int8(p0[3] * scale0); - pp[2] = float2int8(p0[10] * scale1); - pp[3] = float2int8(p0[11] * scale1); - pp[4] = float2int8(p0[18] * scale2); - pp[5] = float2int8(p0[19] * scale2); - pp[6] = float2int8(p0[26] * scale3); - pp[7] = float2int8(p0[27] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[34] * scale4); - pp[9] = float2int8(p0[35] * scale4); - pp[10] = float2int8(p0[42] * scale5); - pp[11] = float2int8(p0[43] * scale5); - pp[12] = float2int8(p0[50] * scale6); - pp[13] = float2int8(p0[51] * scale6); - pp[14] = float2int8(p0[58] * scale7); - pp[15] = float2int8(p0[59] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[34] * scale4); - pp1[1] = float2int8(p0[35] * scale4); - pp1[2] = float2int8(p0[42] * scale5); - pp1[3] = float2int8(p0[43] * scale5); - pp1[4] = float2int8(p0[50] * scale6); - pp1[5] = float2int8(p0[51] * scale6); - pp1[6] = float2int8(p0[58] * scale7); - pp1[7] = float2int8(p0[59] * scale7); - pp += 8; - pp1 += 8; -#endif + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); - pp[0] = float2int8(p0[4] * scale0); - pp[1] = float2int8(p0[5] * scale0); - pp[2] = float2int8(p0[12] * scale1); - pp[3] = float2int8(p0[13] * scale1); - pp[4] = float2int8(p0[20] * scale2); - pp[5] = float2int8(p0[21] * scale2); - pp[6] = float2int8(p0[28] * scale3); - pp[7] = float2int8(p0[29] * scale3); #if __AVX2__ - pp[8] = float2int8(p0[36] * scale4); - pp[9] = float2int8(p0[37] * scale4); - pp[10] = float2int8(p0[44] * scale5); - pp[11] = float2int8(p0[45] * scale5); - pp[12] = float2int8(p0[52] * scale6); - pp[13] = float2int8(p0[53] * scale6); - pp[14] = float2int8(p0[60] * scale7); - pp[15] = float2int8(p0[61] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[36] * scale4); - pp1[1] = float2int8(p0[37] * scale4); - pp1[2] = float2int8(p0[44] * scale5); - pp1[3] = float2int8(p0[45] * scale5); - pp1[4] = float2int8(p0[52] * scale6); - pp1[5] = float2int8(p0[53] * scale6); - pp1[6] = float2int8(p0[60] * scale7); - pp1[7] = float2int8(p0[61] * scale7); - pp += 8; - pp1 += 8; -#endif + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); - pp[0] = float2int8(p0[6] * scale0); - pp[1] = float2int8(p0[7] * scale0); - pp[2] = float2int8(p0[14] * scale1); - pp[3] = float2int8(p0[15] * scale1); - pp[4] = float2int8(p0[22] * scale2); - pp[5] = float2int8(p0[23] * scale2); - pp[6] = float2int8(p0[30] * scale3); - pp[7] = float2int8(p0[31] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[38] * scale4); - pp[9] = float2int8(p0[39] * scale4); - pp[10] = float2int8(p0[46] * scale5); - pp[11] = float2int8(p0[47] * scale5); - pp[12] = float2int8(p0[54] * scale6); - pp[13] = float2int8(p0[55] * scale6); - pp[14] = float2int8(p0[62] * scale7); - 
pp[15] = float2int8(p0[63] * scale7); - pp += 16; + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi32(_t2, _t3); + _t1 = _mm256_unpackhi_epi32(_t2, _t3); + + _t0 = _mm256_permute4x64_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm256_permute4x64_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_storeu_si256((__m256i*)pp, _t0); + _mm256_storeu_si256((__m256i*)(pp + 32), _t1); + pp += 64; #else - pp1[0] = float2int8(p0[38] * scale4); - pp1[1] = float2int8(p0[39] * scale4); - pp1[2] = float2int8(p0[46] * scale5); - pp1[3] = float2int8(p0[47] * scale5); - pp1[4] = float2int8(p0[54] * scale6); - pp1[5] = float2int8(p0[55] * scale6); - pp1[6] = float2int8(p0[62] * scale7); - pp1[7] = float2int8(p0[63] * scale7); - pp += 8; - pp1 += 8; + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi16(_pp2, _pp3); + __m128i _t3 = _mm_unpackhi_epi16(_pp2, _pp3); + _pp0 = _mm_unpacklo_epi16(_t0, _t1); + _pp1 = _mm_unpackhi_epi16(_t0, _t1); + _pp2 = _mm_unpacklo_epi16(_t2, _t3); + _pp3 = _mm_unpackhi_epi16(_t2, _t3); + + __m256i _t4 = combine4x2_epi32(_pp0, _pp1); + __m256i _t5 = combine4x2_epi32(_pp2, _pp3); + + _mm256_storeu_si256((__m256i*)pp, _t4); + _mm256_storeu_si256((__m256i*)pp1, _t5); + pp += 32; + pp1 += 32; #endif - p0 += A_hstep * 8; } #endif // __AVX512VNNI__ || __AVXVNNI__ } if (elempack == 4) { + __m256 _scales0 = _scales; + __m256 _scales1 = _scales; + __m256 _scales2 = _scales; + __m256 _scales3 = _scales; + transpose8x4_ps(_scales0, _scales1, _scales2, _scales3); + int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[4] * scale1); - pp[5] = float2int8(p0[5] * scale1); - pp[6] = float2int8(p0[6] * scale1); - pp[7] = float2int8(p0[7] * scale1); - pp[8] = float2int8(p0[8] * scale2); - pp[9] = float2int8(p0[9] * scale2); - pp[10] = float2int8(p0[10] * scale2); - pp[11] = float2int8(p0[11] * scale2); - pp[12] = float2int8(p0[12] * scale3); - pp[13] = float2int8(p0[13] * scale3); - pp[14] = float2int8(p0[14] * scale3); - pp[15] = float2int8(p0[15] * scale3); - pp[16] = float2int8(p0[16] * scale4); - pp[17] = float2int8(p0[17] * scale4); - pp[18] = float2int8(p0[18] * scale4); - pp[19] = float2int8(p0[19] * scale4); - pp[20] = float2int8(p0[20] * scale5); - pp[21] = float2int8(p0[21] * scale5); - pp[22] = float2int8(p0[22] * scale5); - pp[23] = float2int8(p0[23] * scale5); - pp[24] = float2int8(p0[24] * scale6); - pp[25] = float2int8(p0[25] * scale6); - pp[26] = float2int8(p0[26] * scale6); - pp[27] = float2int8(p0[27] * scale6); - pp[28] = float2int8(p0[28] * scale7); - pp[29] = float2int8(p0[29] * scale7); - pp[30] = float2int8(p0[30] * scale7); - pp[31] = float2int8(p0[31] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - 
w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); + + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #else // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[4] * scale1); - pp[3] = float2int8(p0[5] * scale1); - pp[4] = float2int8(p0[8] * scale2); - pp[5] = float2int8(p0[9] * scale2); - pp[6] = float2int8(p0[12] * scale3); - pp[7] = float2int8(p0[13] * scale3); + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + #if __AVX2__ - pp[8] = float2int8(p0[16] * scale4); - pp[9] = float2int8(p0[17] * scale4); - pp[10] = float2int8(p0[20] * scale5); - pp[11] = float2int8(p0[21] * scale5); - pp[12] = float2int8(p0[24] * scale6); - pp[13] = float2int8(p0[25] * scale6); - pp[14] = float2int8(p0[28] * scale7); - pp[15] = float2int8(p0[29] * scale7); - pp += 16; + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi16(_t0, _t1); + __m128i _t3 = _mm_unpackhi_epi16(_t0, _t1); + _t0 = _mm_unpacklo_epi16(_t2, _t3); + _t1 = _mm_unpackhi_epi16(_t2, _t3); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); + pp += 32; #else - pp1[0] = float2int8(p0[16] * scale4); - pp1[1] = float2int8(p0[17] * scale4); - pp1[2] = float2int8(p0[20] * scale5); - pp1[3] = float2int8(p0[21] * scale5); - pp1[4] = float2int8(p0[24] * scale6); - pp1[5] = float2int8(p0[25] * scale6); - pp1[6] = float2int8(p0[28] * scale7); - pp1[7] = float2int8(p0[29] * scale7); - pp += 8; - pp1 += 8; -#endif + __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); + __m128i _t0 = _mm_shuffle_epi8(_pp0, _si); + __m128i _t1 = _mm_shuffle_epi8(_pp1, _si); - pp[0] = float2int8(p0[2] * scale0); - pp[1] = float2int8(p0[3] * scale0); - pp[2] = float2int8(p0[6] * scale1); - pp[3] = float2int8(p0[7] * scale1); - pp[4] = float2int8(p0[10] * scale2); - pp[5] = float2int8(p0[11] * scale2); - pp[6] = float2int8(p0[14] * scale3); 
- pp[7] = float2int8(p0[15] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[18] * scale4); - pp[9] = float2int8(p0[19] * scale4); - pp[10] = float2int8(p0[22] * scale5); - pp[11] = float2int8(p0[23] * scale5); - pp[12] = float2int8(p0[26] * scale6); - pp[13] = float2int8(p0[27] * scale6); - pp[14] = float2int8(p0[30] * scale7); - pp[15] = float2int8(p0[31] * scale7); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)pp1, _t1); pp += 16; -#else - pp1[0] = float2int8(p0[18] * scale4); - pp1[1] = float2int8(p0[19] * scale4); - pp1[2] = float2int8(p0[22] * scale5); - pp1[3] = float2int8(p0[23] * scale5); - pp1[4] = float2int8(p0[26] * scale6); - pp1[5] = float2int8(p0[27] * scale6); - pp1[6] = float2int8(p0[30] * scale7); - pp1[7] = float2int8(p0[31] * scale7); - pp += 8; - pp1 += 8; + pp1 += 16; #endif p0 += A_hstep * 4; } @@ -5495,125 +4897,61 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[A_hstep] * scale0); - pp[2] = float2int8(p0[A_hstep * 2] * scale0); - pp[3] = float2int8(p0[A_hstep * 3] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[A_hstep + 1] * scale1); - pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); - pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[A_hstep + 2] * scale2); - pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); - pp[11] = float2int8(p0[A_hstep * 3 + 2] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[A_hstep + 3] * scale3); - pp[14] = float2int8(p0[A_hstep * 2 + 3] * scale3); - pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); - pp[16] = float2int8(p0[4] * scale4); - pp[17] = float2int8(p0[A_hstep + 4] * scale4); - pp[18] = float2int8(p0[A_hstep * 2 + 4] * scale4); - pp[19] = float2int8(p0[A_hstep * 3 + 4] * scale4); - pp[20] = float2int8(p0[5] * scale5); - pp[21] = float2int8(p0[A_hstep + 5] * scale5); - pp[22] = float2int8(p0[A_hstep * 2 + 5] * scale5); - pp[23] = float2int8(p0[A_hstep * 3 + 5] * scale5); - pp[24] = float2int8(p0[6] * scale6); - pp[25] = float2int8(p0[A_hstep + 6] * scale6); - pp[26] = float2int8(p0[A_hstep * 2 + 6] * scale6); - pp[27] = float2int8(p0[A_hstep * 3 + 6] * scale6); - pp[28] = float2int8(p0[7] * scale7); - pp[29] = float2int8(p0[A_hstep + 7] * scale7); - pp[30] = float2int8(p0[A_hstep * 2 + 7] * scale7); - pp[31] = float2int8(p0[A_hstep * 3 + 7] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep); + 
__m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 2);
+            __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 3);
+
+            _p0 = _mm256_mul_ps(_p0, _scales);
+            _p1 = _mm256_mul_ps(_p1, _scales);
+            _p2 = _mm256_mul_ps(_p2, _scales);
+            _p3 = _mm256_mul_ps(_p3, _scales);
+
+            __m128i _pp0 = float2int8_avx(_p0, _p2);
+            __m128i _pp1 = float2int8_avx(_p1, _p3);
+
+            __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1);
+            __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1);
+            _pp0 = _mm_unpacklo_epi16(_tt0, _tt1);
+            _pp1 = _mm_unpackhi_epi16(_tt0, _tt1);
+
+            __m256i _pp = combine4x2_epi32(_pp0, _pp1);
+
+            _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp);
+
+            _mm256_storeu_si256((__m256i*)pp, _pp);
+
             pp += 32;
             p0 += A_hstep * 4;
         }
         if (max_kk >= 4)
         {
-            ((int*)pp)[0] = w_shift0 * 127;
-            ((int*)pp)[1] = w_shift1 * 127;
-            ((int*)pp)[2] = w_shift2 * 127;
-            ((int*)pp)[3] = w_shift3 * 127;
-            ((int*)pp)[4] = w_shift4 * 127;
-            ((int*)pp)[5] = w_shift5 * 127;
-            ((int*)pp)[6] = w_shift6 * 127;
-            ((int*)pp)[7] = w_shift7 * 127;
+            _mm256_storeu_si256((__m256i*)pp, _w_shift);
             pp += 32;
         }
 #endif // __AVX512VNNI__ || __AVXVNNI__
         for (; kk + 1 < max_kk; kk += 2)
         {
-            pp[0] = float2int8(p0[0] * scale0);
-            pp[1] = float2int8(p0[A_hstep] * scale0);
-            pp[2] = float2int8(p0[1] * scale1);
-            pp[3] = float2int8(p0[A_hstep + 1] * scale1);
-            pp[4] = float2int8(p0[2] * scale2);
-            pp[5] = float2int8(p0[A_hstep + 2] * scale2);
-            pp[6] = float2int8(p0[3] * scale3);
-            pp[7] = float2int8(p0[A_hstep + 3] * scale3);
+            __m256 _p0 = _mm256_loadu_ps(p0);
+            __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep);
+
+            _p0 = _mm256_mul_ps(_p0, _scales);
+            _p1 = _mm256_mul_ps(_p1, _scales);
+
+            __m128i _pp = float2int8_avx(_p0, _p1);
+
+            __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
+            _pp = _mm_shuffle_epi8(_pp, _si);
+
 #if __AVX2__
-            pp[8] = float2int8(p0[4] * scale4);
-            pp[9] = float2int8(p0[A_hstep + 4] * scale4);
-            pp[10] = float2int8(p0[5] * scale5);
-            pp[11] = float2int8(p0[A_hstep + 5] * scale5);
-            pp[12] = float2int8(p0[6] * scale6);
-            pp[13] = float2int8(p0[A_hstep + 6] * scale6);
-            pp[14] = float2int8(p0[7] * scale7);
-            pp[15] = float2int8(p0[A_hstep + 7] * scale7);
+            _mm_storeu_si128((__m128i*)pp, _pp);
             pp += 16;
 #else
-            pp1[0] = float2int8(p0[4] * scale4);
-            pp1[1] = float2int8(p0[A_hstep + 4] * scale4);
-            pp1[2] = float2int8(p0[5] * scale5);
-            pp1[3] = float2int8(p0[A_hstep + 5] * scale5);
-            pp1[4] = float2int8(p0[6] * scale6);
-            pp1[5] = float2int8(p0[A_hstep + 6] * scale6);
-            pp1[6] = float2int8(p0[7] * scale7);
-            pp1[7] = float2int8(p0[A_hstep + 7] * scale7);
+            _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp));
+            _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp));
             pp += 8;
             pp1 += 8;
 #endif
@@ -5621,21 +4959,18 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
         }
         for (; kk < max_kk; kk++)
         {
-            pp[0] = float2int8(p0[0] * scale0);
-            pp[1] = float2int8(p0[1] * scale1);
-            pp[2] = float2int8(p0[2] * scale2);
-            pp[3] = float2int8(p0[3] * scale3);
+            __m256 _p = _mm256_loadu_ps(p0);
+
+            _p = _mm256_mul_ps(_p, _scales);
+
+            int64_t v = float2int8_avx(_p);
+
 #if __AVX2__
-            pp[4] = float2int8(p0[4] * scale4);
-            pp[5] = float2int8(p0[5] * scale5);
-            pp[6] = float2int8(p0[6] * scale6);
-            pp[7] = float2int8(p0[7] * scale7);
+            *(int64_t*)pp = v;
             pp += 8;
 #else
-            pp1[0] = float2int8(p0[4] * scale4);
-            pp1[1] = float2int8(p0[5] * scale5);
-            pp1[2] = float2int8(p0[6] * scale6);
-            pp1[3] = float2int8(p0[7] * scale7);
+            *(int32_t*)pp = (int32_t)v;
+            *(int32_t*)pp1 = (int32_t)(v >> 32);
             pp += 4;
             pp1 += 4;
 #endif
@@ 
-8189,287 +7524,104 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int { const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; +#if __AVX__ + __m256 _scale = _mm256_set1_ps(scale); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ +#else + __m128 _scale = _mm_set1_ps(scale); +#endif + #if __AVX__ #if __AVX512F__ if (elempack == 16) { + __m512 _scale_avx512 = _mm512_set1_ps(scale); + int kk = 0; #if __AVX512VNNI__ + __m512i _v127_avx512 = _mm512_set1_epi8(127); for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale) + 127; - pp[1] = float2int8(p0[1] * scale) + 127; - pp[2] = float2int8(p0[2 + 0] * scale) + 127; - pp[3] = float2int8(p0[2 + 1] * scale) + 127; - pp[4] = float2int8(p0[16] * scale) + 127; - pp[5] = float2int8(p0[17] * scale) + 127; - pp[6] = float2int8(p0[2 + 16] * scale) + 127; - pp[7] = float2int8(p0[2 + 17] * scale) + 127; - pp[8] = float2int8(p0[32] * scale) + 127; - pp[9] = float2int8(p0[33] * scale) + 127; - pp[10] = float2int8(p0[2 + 32] * scale) + 127; - pp[11] = float2int8(p0[2 + 33] * scale) + 127; - pp[12] = float2int8(p0[48] * scale) + 127; - pp[13] = float2int8(p0[49] * scale) + 127; - pp[14] = float2int8(p0[2 + 48] * scale) + 127; - pp[15] = float2int8(p0[2 + 49] * scale) + 127; - pp[16] = float2int8(p0[64] * scale) + 127; - pp[17] = float2int8(p0[65] * scale) + 127; - pp[18] = float2int8(p0[2 + 64] * scale) + 127; - pp[19] = float2int8(p0[2 + 65] * scale) + 127; - pp[20] = float2int8(p0[80] * scale) + 127; - pp[21] = float2int8(p0[81] * scale) + 127; - pp[22] = float2int8(p0[2 + 80] * scale) + 127; - pp[23] = float2int8(p0[2 + 81] * scale) + 127; - pp[24] = float2int8(p0[96] * scale) + 127; - pp[25] = float2int8(p0[97] * scale) + 127; - pp[26] = float2int8(p0[2 + 96] * scale) + 127; - pp[27] = float2int8(p0[2 + 97] * scale) + 127; - pp[28] = float2int8(p0[112] * scale) + 127; - pp[29] = float2int8(p0[113] * scale) + 127; - pp[30] = float2int8(p0[2 + 112] * scale) + 127; - pp[31] = float2int8(p0[2 + 113] * scale) + 127; - - pp[32 + 0] = float2int8(p0[4 + 0] * scale) + 127; - pp[32 + 1] = float2int8(p0[4 + 1] * scale) + 127; - pp[32 + 2] = float2int8(p0[6 + 0] * scale) + 127; - pp[32 + 3] = float2int8(p0[6 + 1] * scale) + 127; - pp[32 + 4] = float2int8(p0[4 + 16] * scale) + 127; - pp[32 + 5] = float2int8(p0[4 + 17] * scale) + 127; - pp[32 + 6] = float2int8(p0[6 + 16] * scale) + 127; - pp[32 + 7] = float2int8(p0[6 + 17] * scale) + 127; - pp[32 + 8] = float2int8(p0[4 + 32] * scale) + 127; - pp[32 + 9] = float2int8(p0[4 + 33] * scale) + 127; - pp[32 + 10] = float2int8(p0[6 + 32] * scale) + 127; - pp[32 + 11] = float2int8(p0[6 + 33] * scale) + 127; - pp[32 + 12] = float2int8(p0[4 + 48] * scale) + 127; - pp[32 + 13] = float2int8(p0[4 + 49] * scale) + 127; - pp[32 + 14] = float2int8(p0[6 + 48] * scale) + 127; - pp[32 + 15] = float2int8(p0[6 + 49] * scale) + 127; - pp[32 + 16] = float2int8(p0[4 + 64] * scale) + 127; - pp[32 + 17] = float2int8(p0[4 + 65] * scale) + 127; - pp[32 + 18] = float2int8(p0[6 + 64] * scale) + 127; - pp[32 + 19] = float2int8(p0[6 + 65] * scale) + 127; - pp[32 + 20] = float2int8(p0[4 + 80] * scale) + 127; - pp[32 + 21] = float2int8(p0[4 + 81] * scale) + 127; - pp[32 + 22] = float2int8(p0[6 + 80] * scale) + 127; - pp[32 + 23] = float2int8(p0[6 + 81] * scale) + 127; - pp[32 + 24] = float2int8(p0[4 + 96] * scale) + 127; - pp[32 + 25] = float2int8(p0[4 + 97] * scale) + 127; - pp[32 + 26] = float2int8(p0[6 + 96] * scale) + 
127; - pp[32 + 27] = float2int8(p0[6 + 97] * scale) + 127; - pp[32 + 28] = float2int8(p0[4 + 112] * scale) + 127; - pp[32 + 29] = float2int8(p0[4 + 113] * scale) + 127; - pp[32 + 30] = float2int8(p0[6 + 112] * scale) + 127; - pp[32 + 31] = float2int8(p0[6 + 113] * scale) + 127; - - pp[64 + 0] = float2int8(p0[8 + 0] * scale) + 127; - pp[64 + 1] = float2int8(p0[8 + 1] * scale) + 127; - pp[64 + 2] = float2int8(p0[10 + 0] * scale) + 127; - pp[64 + 3] = float2int8(p0[10 + 1] * scale) + 127; - pp[64 + 4] = float2int8(p0[8 + 16] * scale) + 127; - pp[64 + 5] = float2int8(p0[8 + 17] * scale) + 127; - pp[64 + 6] = float2int8(p0[10 + 16] * scale) + 127; - pp[64 + 7] = float2int8(p0[10 + 17] * scale) + 127; - pp[64 + 8] = float2int8(p0[8 + 32] * scale) + 127; - pp[64 + 9] = float2int8(p0[8 + 33] * scale) + 127; - pp[64 + 10] = float2int8(p0[10 + 32] * scale) + 127; - pp[64 + 11] = float2int8(p0[10 + 33] * scale) + 127; - pp[64 + 12] = float2int8(p0[8 + 48] * scale) + 127; - pp[64 + 13] = float2int8(p0[8 + 49] * scale) + 127; - pp[64 + 14] = float2int8(p0[10 + 48] * scale) + 127; - pp[64 + 15] = float2int8(p0[10 + 49] * scale) + 127; - pp[64 + 16] = float2int8(p0[8 + 64] * scale) + 127; - pp[64 + 17] = float2int8(p0[8 + 65] * scale) + 127; - pp[64 + 18] = float2int8(p0[10 + 64] * scale) + 127; - pp[64 + 19] = float2int8(p0[10 + 65] * scale) + 127; - pp[64 + 20] = float2int8(p0[8 + 80] * scale) + 127; - pp[64 + 21] = float2int8(p0[8 + 81] * scale) + 127; - pp[64 + 22] = float2int8(p0[10 + 80] * scale) + 127; - pp[64 + 23] = float2int8(p0[10 + 81] * scale) + 127; - pp[64 + 24] = float2int8(p0[8 + 96] * scale) + 127; - pp[64 + 25] = float2int8(p0[8 + 97] * scale) + 127; - pp[64 + 26] = float2int8(p0[10 + 96] * scale) + 127; - pp[64 + 27] = float2int8(p0[10 + 97] * scale) + 127; - pp[64 + 28] = float2int8(p0[8 + 112] * scale) + 127; - pp[64 + 29] = float2int8(p0[8 + 113] * scale) + 127; - pp[64 + 30] = float2int8(p0[10 + 112] * scale) + 127; - pp[64 + 31] = float2int8(p0[10 + 113] * scale) + 127; - - pp[96 + 0] = float2int8(p0[12 + 0] * scale) + 127; - pp[96 + 1] = float2int8(p0[12 + 1] * scale) + 127; - pp[96 + 2] = float2int8(p0[14 + 0] * scale) + 127; - pp[96 + 3] = float2int8(p0[14 + 1] * scale) + 127; - pp[96 + 4] = float2int8(p0[12 + 16] * scale) + 127; - pp[96 + 5] = float2int8(p0[12 + 17] * scale) + 127; - pp[96 + 6] = float2int8(p0[14 + 16] * scale) + 127; - pp[96 + 7] = float2int8(p0[14 + 17] * scale) + 127; - pp[96 + 8] = float2int8(p0[12 + 32] * scale) + 127; - pp[96 + 9] = float2int8(p0[12 + 33] * scale) + 127; - pp[96 + 10] = float2int8(p0[14 + 32] * scale) + 127; - pp[96 + 11] = float2int8(p0[14 + 33] * scale) + 127; - pp[96 + 12] = float2int8(p0[12 + 48] * scale) + 127; - pp[96 + 13] = float2int8(p0[12 + 49] * scale) + 127; - pp[96 + 14] = float2int8(p0[14 + 48] * scale) + 127; - pp[96 + 15] = float2int8(p0[14 + 49] * scale) + 127; - pp[96 + 16] = float2int8(p0[12 + 64] * scale) + 127; - pp[96 + 17] = float2int8(p0[12 + 65] * scale) + 127; - pp[96 + 18] = float2int8(p0[14 + 64] * scale) + 127; - pp[96 + 19] = float2int8(p0[14 + 65] * scale) + 127; - pp[96 + 20] = float2int8(p0[12 + 80] * scale) + 127; - pp[96 + 21] = float2int8(p0[12 + 81] * scale) + 127; - pp[96 + 22] = float2int8(p0[14 + 80] * scale) + 127; - pp[96 + 23] = float2int8(p0[14 + 81] * scale) + 127; - pp[96 + 24] = float2int8(p0[12 + 96] * scale) + 127; - pp[96 + 25] = float2int8(p0[12 + 97] * scale) + 127; - pp[96 + 26] = float2int8(p0[14 + 96] * scale) + 127; - pp[96 + 27] = float2int8(p0[14 + 97] * scale) + 127; - pp[96 + 
28] = float2int8(p0[12 + 112] * scale) + 127; - pp[96 + 29] = float2int8(p0[12 + 113] * scale) + 127; - pp[96 + 30] = float2int8(p0[14 + 112] * scale) + 127; - pp[96 + 31] = float2int8(p0[14 + 113] * scale) + 127; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale_avx512); + _p1 = _mm512_mul_ps(_p1, _scale_avx512); + _p2 = _mm512_mul_ps(_p2, _scale_avx512); + _p3 = _mm512_mul_ps(_p3, _scale_avx512); + _p4 = _mm512_mul_ps(_p4, _scale_avx512); + _p5 = _mm512_mul_ps(_p5, _scale_avx512); + _p6 = _mm512_mul_ps(_p6, _scale_avx512); + _p7 = _mm512_mul_ps(_p7, _scale_avx512); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose4x8_epi32(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _t0 = _mm512_add_epi8(_t0, _v127_avx512); + _t1 = _mm512_add_epi8(_t1, _v127_avx512); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); + + pp += 128; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale_avx512); + _p1 = _mm512_mul_ps(_p1, _scale_avx512); + _p2 = _mm512_mul_ps(_p2, _scale_avx512); + _p3 = _mm512_mul_ps(_p3, _scale_avx512); + _p4 = _mm512_mul_ps(_p4, _scale_avx512); + _p5 = _mm512_mul_ps(_p5, _scale_avx512); + _p6 = _mm512_mul_ps(_p6, _scale_avx512); + _p7 = _mm512_mul_ps(_p7, _scale_avx512); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); - pp += 128; - p0 += B_hstep * 16; - } -#else // __AVX512VNNI__ - for (; kk + 15 < max_kk; kk += 16) - { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[16] * scale); - pp[3] = float2int8(p0[17] * scale); - pp[4] = float2int8(p0[32] * scale); - pp[5] = float2int8(p0[33] * scale); - pp[6] = float2int8(p0[48] * scale); - pp[7] = float2int8(p0[49] * scale); - pp[8] = float2int8(p0[64] * scale); - pp[9] = float2int8(p0[65] * scale); - pp[10] = float2int8(p0[80] * scale); - pp[11] = float2int8(p0[81] * scale); - pp[12] = float2int8(p0[96] * scale); - pp[13] = float2int8(p0[97] * scale); - pp[14] = float2int8(p0[112] * scale); - pp[15] = float2int8(p0[113] * scale); - - pp[16 + 0] = float2int8(p0[2 + 0] * scale); - pp[16 + 1] = float2int8(p0[2 + 1] * scale); - pp[16 + 2] = float2int8(p0[2 + 16] * scale); 
- pp[16 + 3] = float2int8(p0[2 + 17] * scale); - pp[16 + 4] = float2int8(p0[2 + 32] * scale); - pp[16 + 5] = float2int8(p0[2 + 33] * scale); - pp[16 + 6] = float2int8(p0[2 + 48] * scale); - pp[16 + 7] = float2int8(p0[2 + 49] * scale); - pp[16 + 8] = float2int8(p0[2 + 64] * scale); - pp[16 + 9] = float2int8(p0[2 + 65] * scale); - pp[16 + 10] = float2int8(p0[2 + 80] * scale); - pp[16 + 11] = float2int8(p0[2 + 81] * scale); - pp[16 + 12] = float2int8(p0[2 + 96] * scale); - pp[16 + 13] = float2int8(p0[2 + 97] * scale); - pp[16 + 14] = float2int8(p0[2 + 112] * scale); - pp[16 + 15] = float2int8(p0[2 + 113] * scale); - - pp[32 + 0] = float2int8(p0[4 + 0] * scale); - pp[32 + 1] = float2int8(p0[4 + 1] * scale); - pp[32 + 2] = float2int8(p0[4 + 16] * scale); - pp[32 + 3] = float2int8(p0[4 + 17] * scale); - pp[32 + 4] = float2int8(p0[4 + 32] * scale); - pp[32 + 5] = float2int8(p0[4 + 33] * scale); - pp[32 + 6] = float2int8(p0[4 + 48] * scale); - pp[32 + 7] = float2int8(p0[4 + 49] * scale); - pp[32 + 8] = float2int8(p0[4 + 64] * scale); - pp[32 + 9] = float2int8(p0[4 + 65] * scale); - pp[32 + 10] = float2int8(p0[4 + 80] * scale); - pp[32 + 11] = float2int8(p0[4 + 81] * scale); - pp[32 + 12] = float2int8(p0[4 + 96] * scale); - pp[32 + 13] = float2int8(p0[4 + 97] * scale); - pp[32 + 14] = float2int8(p0[4 + 112] * scale); - pp[32 + 15] = float2int8(p0[4 + 113] * scale); - - pp[48 + 0] = float2int8(p0[6 + 0] * scale); - pp[48 + 1] = float2int8(p0[6 + 1] * scale); - pp[48 + 2] = float2int8(p0[6 + 16] * scale); - pp[48 + 3] = float2int8(p0[6 + 17] * scale); - pp[48 + 4] = float2int8(p0[6 + 32] * scale); - pp[48 + 5] = float2int8(p0[6 + 33] * scale); - pp[48 + 6] = float2int8(p0[6 + 48] * scale); - pp[48 + 7] = float2int8(p0[6 + 49] * scale); - pp[48 + 8] = float2int8(p0[6 + 64] * scale); - pp[48 + 9] = float2int8(p0[6 + 65] * scale); - pp[48 + 10] = float2int8(p0[6 + 80] * scale); - pp[48 + 11] = float2int8(p0[6 + 81] * scale); - pp[48 + 12] = float2int8(p0[6 + 96] * scale); - pp[48 + 13] = float2int8(p0[6 + 97] * scale); - pp[48 + 14] = float2int8(p0[6 + 112] * scale); - pp[48 + 15] = float2int8(p0[6 + 113] * scale); - - pp[64 + 0] = float2int8(p0[8 + 0] * scale); - pp[64 + 1] = float2int8(p0[8 + 1] * scale); - pp[64 + 2] = float2int8(p0[8 + 16] * scale); - pp[64 + 3] = float2int8(p0[8 + 17] * scale); - pp[64 + 4] = float2int8(p0[8 + 32] * scale); - pp[64 + 5] = float2int8(p0[8 + 33] * scale); - pp[64 + 6] = float2int8(p0[8 + 48] * scale); - pp[64 + 7] = float2int8(p0[8 + 49] * scale); - pp[64 + 8] = float2int8(p0[8 + 64] * scale); - pp[64 + 9] = float2int8(p0[8 + 65] * scale); - pp[64 + 10] = float2int8(p0[8 + 80] * scale); - pp[64 + 11] = float2int8(p0[8 + 81] * scale); - pp[64 + 12] = float2int8(p0[8 + 96] * scale); - pp[64 + 13] = float2int8(p0[8 + 97] * scale); - pp[64 + 14] = float2int8(p0[8 + 112] * scale); - pp[64 + 15] = float2int8(p0[8 + 113] * scale); - - pp[80 + 0] = float2int8(p0[10 + 0] * scale); - pp[80 + 1] = float2int8(p0[10 + 1] * scale); - pp[80 + 2] = float2int8(p0[10 + 16] * scale); - pp[80 + 3] = float2int8(p0[10 + 17] * scale); - pp[80 + 4] = float2int8(p0[10 + 32] * scale); - pp[80 + 5] = float2int8(p0[10 + 33] * scale); - pp[80 + 6] = float2int8(p0[10 + 48] * scale); - pp[80 + 7] = float2int8(p0[10 + 49] * scale); - pp[80 + 8] = float2int8(p0[10 + 64] * scale); - pp[80 + 9] = float2int8(p0[10 + 65] * scale); - pp[80 + 10] = float2int8(p0[10 + 80] * scale); - pp[80 + 11] = float2int8(p0[10 + 81] * scale); - pp[80 + 12] = float2int8(p0[10 + 96] * scale); - pp[80 + 13] = 
-                pp[80 + 14] = float2int8(p0[10 + 112] * scale);
-                pp[80 + 15] = float2int8(p0[10 + 113] * scale);
-
-                pp[96 + 0] = float2int8(p0[12 + 0] * scale);
-                pp[96 + 1] = float2int8(p0[12 + 1] * scale);
-                pp[96 + 2] = float2int8(p0[12 + 16] * scale);
-                pp[96 + 3] = float2int8(p0[12 + 17] * scale);
-                pp[96 + 4] = float2int8(p0[12 + 32] * scale);
-                pp[96 + 5] = float2int8(p0[12 + 33] * scale);
-                pp[96 + 6] = float2int8(p0[12 + 48] * scale);
-                pp[96 + 7] = float2int8(p0[12 + 49] * scale);
-                pp[96 + 8] = float2int8(p0[12 + 64] * scale);
-                pp[96 + 9] = float2int8(p0[12 + 65] * scale);
-                pp[96 + 10] = float2int8(p0[12 + 80] * scale);
-                pp[96 + 11] = float2int8(p0[12 + 81] * scale);
-                pp[96 + 12] = float2int8(p0[12 + 96] * scale);
-                pp[96 + 13] = float2int8(p0[12 + 97] * scale);
-                pp[96 + 14] = float2int8(p0[12 + 112] * scale);
-                pp[96 + 15] = float2int8(p0[12 + 113] * scale);
-
-                pp[112 + 0] = float2int8(p0[14 + 0] * scale);
-                pp[112 + 1] = float2int8(p0[14 + 1] * scale);
-                pp[112 + 2] = float2int8(p0[14 + 16] * scale);
-                pp[112 + 3] = float2int8(p0[14 + 17] * scale);
-                pp[112 + 4] = float2int8(p0[14 + 32] * scale);
-                pp[112 + 5] = float2int8(p0[14 + 33] * scale);
-                pp[112 + 6] = float2int8(p0[14 + 48] * scale);
-                pp[112 + 7] = float2int8(p0[14 + 49] * scale);
-                pp[112 + 8] = float2int8(p0[14 + 64] * scale);
-                pp[112 + 9] = float2int8(p0[14 + 65] * scale);
-                pp[112 + 10] = float2int8(p0[14 + 80] * scale);
-                pp[112 + 11] = float2int8(p0[14 + 81] * scale);
-                pp[112 + 12] = float2int8(p0[14 + 96] * scale);
-                pp[112 + 13] = float2int8(p0[14 + 97] * scale);
-                pp[112 + 14] = float2int8(p0[14 + 112] * scale);
-                pp[112 + 15] = float2int8(p0[14 + 113] * scale);
+                transpose8x8_epi16(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7);
+
+                __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3);
+                __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7);
+
+                _mm512_storeu_si512((__m512i*)pp, _t0);
+                _mm512_storeu_si512((__m512i*)(pp + 64), _t1);

                 pp += 128;
                 p0 += B_hstep * 16;
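The elempack-16 path above replaces 128 scalar float2int8 calls per iteration with eight vector conversions plus an in-register transpose. For reference, a minimal sketch of what a float2int8_avx512-style helper has to do, assuming it mirrors the scalar float2int8 semantics (round to nearest, saturate to [-127, 127]); the name and exact implementation in gemm_int8.h may differ:

#include <immintrin.h>

// Hypothetical sketch, not the ncnn implementation: convert 16 floats to
// 16 signed int8 values packed into a __m128i.
static inline __m128i float2int8_avx512_sketch(__m512 _v)
{
    __m512i _i = _mm512_cvtps_epi32(_v);          // round to nearest (default MXCSR mode)
    __m128i _b = _mm512_cvtsepi32_epi8(_i);       // narrow 32-bit -> 8-bit with signed saturation
    return _mm_max_epi8(_b, _mm_set1_epi8(-127)); // match the scalar path's -127 lower clamp
}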
@@ -8483,71 +7635,42 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
 #if __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 7 < max_kk; kk += 8)
             {
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[1] * scale) + 127;
-                pp[2] = float2int8(p0[2] * scale) + 127;
-                pp[3] = float2int8(p0[3] * scale) + 127;
-                pp[4] = float2int8(p0[8] * scale) + 127;
-                pp[5] = float2int8(p0[9] * scale) + 127;
-                pp[6] = float2int8(p0[10] * scale) + 127;
-                pp[7] = float2int8(p0[11] * scale) + 127;
-                pp[8] = float2int8(p0[16] * scale) + 127;
-                pp[9] = float2int8(p0[17] * scale) + 127;
-                pp[10] = float2int8(p0[18] * scale) + 127;
-                pp[11] = float2int8(p0[19] * scale) + 127;
-                pp[12] = float2int8(p0[24] * scale) + 127;
-                pp[13] = float2int8(p0[25] * scale) + 127;
-                pp[14] = float2int8(p0[26] * scale) + 127;
-                pp[15] = float2int8(p0[27] * scale) + 127;
-                pp[16] = float2int8(p0[32] * scale) + 127;
-                pp[17] = float2int8(p0[33] * scale) + 127;
-                pp[18] = float2int8(p0[34] * scale) + 127;
-                pp[19] = float2int8(p0[35] * scale) + 127;
-                pp[20] = float2int8(p0[40] * scale) + 127;
-                pp[21] = float2int8(p0[41] * scale) + 127;
-                pp[22] = float2int8(p0[42] * scale) + 127;
-                pp[23] = float2int8(p0[43] * scale) + 127;
-                pp[24] = float2int8(p0[48] * scale) + 127;
-                pp[25] = float2int8(p0[49] * scale) + 127;
-                pp[26] = float2int8(p0[50] * scale) + 127;
-                pp[27] = float2int8(p0[51] * scale) + 127;
-                pp[28] = float2int8(p0[56] * scale) + 127;
-                pp[29] = float2int8(p0[57] * scale) + 127;
-                pp[30] = float2int8(p0[58] * scale) + 127;
-                pp[31] = float2int8(p0[59] * scale) + 127;
-
-                pp[32 + 0] = float2int8(p0[4] * scale) + 127;
-                pp[32 + 1] = float2int8(p0[5] * scale) + 127;
-                pp[32 + 2] = float2int8(p0[6] * scale) + 127;
-                pp[32 + 3] = float2int8(p0[7] * scale) + 127;
-                pp[32 + 4] = float2int8(p0[12] * scale) + 127;
-                pp[32 + 5] = float2int8(p0[13] * scale) + 127;
-                pp[32 + 6] = float2int8(p0[14] * scale) + 127;
-                pp[32 + 7] = float2int8(p0[15] * scale) + 127;
-                pp[32 + 8] = float2int8(p0[20] * scale) + 127;
-                pp[32 + 9] = float2int8(p0[21] * scale) + 127;
-                pp[32 + 10] = float2int8(p0[22] * scale) + 127;
-                pp[32 + 11] = float2int8(p0[23] * scale) + 127;
-                pp[32 + 12] = float2int8(p0[28] * scale) + 127;
-                pp[32 + 13] = float2int8(p0[29] * scale) + 127;
-                pp[32 + 14] = float2int8(p0[30] * scale) + 127;
-                pp[32 + 15] = float2int8(p0[31] * scale) + 127;
-                pp[32 + 16] = float2int8(p0[36] * scale) + 127;
-                pp[32 + 17] = float2int8(p0[37] * scale) + 127;
-                pp[32 + 18] = float2int8(p0[38] * scale) + 127;
-                pp[32 + 19] = float2int8(p0[39] * scale) + 127;
-                pp[32 + 20] = float2int8(p0[44] * scale) + 127;
-                pp[32 + 21] = float2int8(p0[45] * scale) + 127;
-                pp[32 + 22] = float2int8(p0[46] * scale) + 127;
-                pp[32 + 23] = float2int8(p0[47] * scale) + 127;
-                pp[32 + 24] = float2int8(p0[52] * scale) + 127;
-                pp[32 + 25] = float2int8(p0[53] * scale) + 127;
-                pp[32 + 26] = float2int8(p0[54] * scale) + 127;
-                pp[32 + 27] = float2int8(p0[55] * scale) + 127;
-                pp[32 + 28] = float2int8(p0[60] * scale) + 127;
-                pp[32 + 29] = float2int8(p0[61] * scale) + 127;
-                pp[32 + 30] = float2int8(p0[62] * scale) + 127;
-                pp[32 + 31] = float2int8(p0[63] * scale) + 127;
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + 8);
+                __m256 _p2 = _mm256_loadu_ps(p0 + 16);
+                __m256 _p3 = _mm256_loadu_ps(p0 + 24);
+                __m256 _p4 = _mm256_loadu_ps(p0 + 32);
+                __m256 _p5 = _mm256_loadu_ps(p0 + 40);
+                __m256 _p6 = _mm256_loadu_ps(p0 + 48);
+                __m256 _p7 = _mm256_loadu_ps(p0 + 56);
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+                _p2 = _mm256_mul_ps(_p2, _scale);
+                _p3 = _mm256_mul_ps(_p3, _scale);
+                _p4 = _mm256_mul_ps(_p4, _scale);
+                _p5 = _mm256_mul_ps(_p5, _scale);
+                _p6 = _mm256_mul_ps(_p6, _scale);
+                _p7 = _mm256_mul_ps(_p7, _scale);
+
+                __m128i _pp0 = float2int8_avx(_p0, _p2);
+                __m128i _pp1 = float2int8_avx(_p1, _p3);
+                __m128i _pp2 = float2int8_avx(_p4, _p6);
+                __m128i _pp3 = float2int8_avx(_p5, _p7);
+
+                __m256i _t0 = combine4x2_epi32(_pp0, _pp2);
+                __m256i _t1 = combine4x2_epi32(_pp1, _pp3);
+
+                __m256i _t2 = _mm256_unpacklo_epi32(_t0, _t1);
+                __m256i _t3 = _mm256_unpackhi_epi32(_t0, _t1);
+                _t0 = _mm256_unpacklo_epi64(_t2, _t3);
+                _t1 = _mm256_unpackhi_epi64(_t2, _t3);
+
+                _t0 = _mm256_add_epi8(_t0, _v127);
+                _t1 = _mm256_add_epi8(_t1, _v127);
+
+                _mm256_storeu_si256((__m256i*)pp, _t0);
+                _mm256_storeu_si256((__m256i*)(pp + 32), _t1);

                 pp += 64;
                 p0 += B_hstep * 8;
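The +127 in the VNNI branches biases the signed int8 values into unsigned range, because vpdpbusd multiplies an unsigned byte operand by a signed one. The bias is exact to undo. An illustrative scalar model (not ncnn code; the A-side packing accumulates the matching correction term, seen as the w_shift accumulator elsewhere in this patch):

// For one output element:
//   sum((b[i] + 127) * a[i]) = sum(b[i] * a[i]) + 127 * sum(a[i]),
// so subtracting 127 * sum(a[i]), computed once at pack time, recovers
// the true signed dot product.
int biased_dot(const signed char* a, const signed char* b, int n)
{
    int acc = 0;
    int correction = 0;
    for (int i = 0; i < n; i++)
    {
        acc += (b[i] + 127) * a[i]; // what the kernel sees after packing
        correction += a[i];         // the pack-time shift accumulator
    }
    return acc - 127 * correction; // equals sum(a[i] * b[i])
}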
@@ -8555,78 +7678,54 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
 #else // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 7 < max_kk; kk += 8)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[8] * scale);
-                pp[3] = float2int8(p0[9] * scale);
-                pp[4] = float2int8(p0[16] * scale);
-                pp[5] = float2int8(p0[17] * scale);
-                pp[6] = float2int8(p0[24] * scale);
-                pp[7] = float2int8(p0[25] * scale);
-                pp[8] = float2int8(p0[32] * scale);
-                pp[9] = float2int8(p0[33] * scale);
-                pp[10] = float2int8(p0[40] * scale);
-                pp[11] = float2int8(p0[41] * scale);
-                pp[12] = float2int8(p0[48] * scale);
-                pp[13] = float2int8(p0[49] * scale);
-                pp[14] = float2int8(p0[56] * scale);
-                pp[15] = float2int8(p0[57] * scale);
-                pp += 16;
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + 8);
+                __m256 _p2 = _mm256_loadu_ps(p0 + 16);
+                __m256 _p3 = _mm256_loadu_ps(p0 + 24);
+                __m256 _p4 = _mm256_loadu_ps(p0 + 32);
+                __m256 _p5 = _mm256_loadu_ps(p0 + 40);
+                __m256 _p6 = _mm256_loadu_ps(p0 + 48);
+                __m256 _p7 = _mm256_loadu_ps(p0 + 56);
-
-                pp[0] = float2int8(p0[2] * scale);
-                pp[1] = float2int8(p0[3] * scale);
-                pp[2] = float2int8(p0[10] * scale);
-                pp[3] = float2int8(p0[11] * scale);
-                pp[4] = float2int8(p0[18] * scale);
-                pp[5] = float2int8(p0[19] * scale);
-                pp[6] = float2int8(p0[26] * scale);
-                pp[7] = float2int8(p0[27] * scale);
-                pp[8] = float2int8(p0[34] * scale);
-                pp[9] = float2int8(p0[35] * scale);
-                pp[10] = float2int8(p0[42] * scale);
-                pp[11] = float2int8(p0[43] * scale);
-                pp[12] = float2int8(p0[50] * scale);
-                pp[13] = float2int8(p0[51] * scale);
-                pp[14] = float2int8(p0[58] * scale);
-                pp[15] = float2int8(p0[59] * scale);
-                pp += 16;
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+                _p2 = _mm256_mul_ps(_p2, _scale);
+                _p3 = _mm256_mul_ps(_p3, _scale);
+                _p4 = _mm256_mul_ps(_p4, _scale);
+                _p5 = _mm256_mul_ps(_p5, _scale);
+                _p6 = _mm256_mul_ps(_p6, _scale);
+                _p7 = _mm256_mul_ps(_p7, _scale);
-
-                pp[0] = float2int8(p0[4] * scale);
-                pp[1] = float2int8(p0[5] * scale);
-                pp[2] = float2int8(p0[12] * scale);
-                pp[3] = float2int8(p0[13] * scale);
-                pp[4] = float2int8(p0[20] * scale);
-                pp[5] = float2int8(p0[21] * scale);
-                pp[6] = float2int8(p0[28] * scale);
-                pp[7] = float2int8(p0[29] * scale);
-                pp[8] = float2int8(p0[36] * scale);
-                pp[9] = float2int8(p0[37] * scale);
-                pp[10] = float2int8(p0[44] * scale);
-                pp[11] = float2int8(p0[45] * scale);
-                pp[12] = float2int8(p0[52] * scale);
-                pp[13] = float2int8(p0[53] * scale);
-                pp[14] = float2int8(p0[60] * scale);
-                pp[15] = float2int8(p0[61] * scale);
-                pp += 16;
+
+                __m128i _pp0 = float2int8_avx(_p0, _p2);
+                __m128i _pp1 = float2int8_avx(_p1, _p3);
+                __m128i _pp2 = float2int8_avx(_p4, _p6);
+                __m128i _pp3 = float2int8_avx(_p5, _p7);
-
-                pp[0] = float2int8(p0[6] * scale);
-                pp[1] = float2int8(p0[7] * scale);
-                pp[2] = float2int8(p0[14] * scale);
-                pp[3] = float2int8(p0[15] * scale);
-                pp[4] = float2int8(p0[22] * scale);
-                pp[5] = float2int8(p0[23] * scale);
-                pp[6] = float2int8(p0[30] * scale);
-                pp[7] = float2int8(p0[31] * scale);
-                pp[8] = float2int8(p0[38] * scale);
-                pp[9] = float2int8(p0[39] * scale);
-                pp[10] = float2int8(p0[46] * scale);
-                pp[11] = float2int8(p0[47] * scale);
-                pp[12] = float2int8(p0[54] * scale);
-                pp[13] = float2int8(p0[55] * scale);
-                pp[14] = float2int8(p0[62] * scale);
-                pp[15] = float2int8(p0[63] * scale);
-                pp += 16;
+
+#if __AVX2__
+                __m256i _t0 = combine4x2_epi32(_pp0, _pp2);
+                __m256i _t1 = combine4x2_epi32(_pp1, _pp3);
+                __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1);
+                __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1);
+                _t0 = _mm256_unpacklo_epi32(_t2, _t3);
+                _t1 = _mm256_unpackhi_epi32(_t2, _t3);
+                _t0 = _mm256_permute4x64_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0));
+                _t1 = _mm256_permute4x64_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0));
+#else
+                __m128i _tt0 = _mm_unpacklo_epi16(_pp0, _pp1);
+                __m128i _tt1 = _mm_unpackhi_epi16(_pp0, _pp1);
+                __m128i _tt2 = _mm_unpacklo_epi16(_pp2, _pp3);
+                __m128i _tt3 = _mm_unpackhi_epi16(_pp2, _pp3);
+                _pp0 = _mm_unpacklo_epi16(_tt0, _tt1);
+                _pp1 = _mm_unpackhi_epi16(_tt0, _tt1);
+                _pp2 = _mm_unpacklo_epi16(_tt2, _tt3);
+                _pp3 = _mm_unpackhi_epi16(_tt2, _tt3);
+                __m256i _t0 = combine4x2_epi32(_pp0, _pp1);
+                __m256i _t1 = combine4x2_epi32(_pp2, _pp3);
+#endif
+                _mm256_storeu_si256((__m256i*)pp, _t0);
+                _mm256_storeu_si256((__m256i*)(pp + 32), _t1);
+
+                pp += 64;
                 p0 += B_hstep * 8;
             }
 #endif // __AVX512VNNI__ || __AVXVNNI__
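Without VNNI the pack keeps k in pairs of int8, so each column contributes one 16-bit lane and the column reshuffle above is effectively an 8x8 16-bit transpose (unpacklo/unpackhi passes, plus a permute4x64 fixup on AVX2). A two-register toy version of the unpack idiom, under the assumption that float2int8_avx emits the column pairs in order:

#include <emmintrin.h>

// Toy demonstration, illustrative only.
// a holds columns c0..c7 as int8 k-pairs (c0k0 c0k1 | c1k0 c1k1 | ...),
// b holds columns c8..c15 in the same layout.
static inline void interleave_column_pairs(__m128i a, __m128i b, __m128i* lo, __m128i* hi)
{
    *lo = _mm_unpacklo_epi16(a, b); // c0,c8,c1,c9,... (each 16-bit element is one k-pair)
    *hi = _mm_unpackhi_epi16(a, b); // c4,c12,c5,c13,...
}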
@@ -8638,77 +7737,75 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
 #if __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[1] * scale) + 127;
-                pp[2] = float2int8(p0[2] * scale) + 127;
-                pp[3] = float2int8(p0[3] * scale) + 127;
-                pp[4] = float2int8(p0[4] * scale) + 127;
-                pp[5] = float2int8(p0[5] * scale) + 127;
-                pp[6] = float2int8(p0[6] * scale) + 127;
-                pp[7] = float2int8(p0[7] * scale) + 127;
-                pp[8] = float2int8(p0[8] * scale) + 127;
-                pp[9] = float2int8(p0[9] * scale) + 127;
-                pp[10] = float2int8(p0[10] * scale) + 127;
-                pp[11] = float2int8(p0[11] * scale) + 127;
-                pp[12] = float2int8(p0[12] * scale) + 127;
-                pp[13] = float2int8(p0[13] * scale) + 127;
-                pp[14] = float2int8(p0[14] * scale) + 127;
-                pp[15] = float2int8(p0[15] * scale) + 127;
-                pp[16] = float2int8(p0[16] * scale) + 127;
-                pp[17] = float2int8(p0[17] * scale) + 127;
-                pp[18] = float2int8(p0[18] * scale) + 127;
-                pp[19] = float2int8(p0[19] * scale) + 127;
-                pp[20] = float2int8(p0[20] * scale) + 127;
-                pp[21] = float2int8(p0[21] * scale) + 127;
-                pp[22] = float2int8(p0[22] * scale) + 127;
-                pp[23] = float2int8(p0[23] * scale) + 127;
-                pp[24] = float2int8(p0[24] * scale) + 127;
-                pp[25] = float2int8(p0[25] * scale) + 127;
-                pp[26] = float2int8(p0[26] * scale) + 127;
-                pp[27] = float2int8(p0[27] * scale) + 127;
-                pp[28] = float2int8(p0[28] * scale) + 127;
-                pp[29] = float2int8(p0[29] * scale) + 127;
-                pp[30] = float2int8(p0[30] * scale) + 127;
-                pp[31] = float2int8(p0[31] * scale) + 127;
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + 8);
+                __m256 _p2 = _mm256_loadu_ps(p0 + 16);
+                __m256 _p3 = _mm256_loadu_ps(p0 + 24);
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+                _p2 = _mm256_mul_ps(_p2, _scale);
+                _p3 = _mm256_mul_ps(_p3, _scale);
+
+                __m128i _pp0 = float2int8_avx(_p0, _p1);
+                __m128i _pp1 = float2int8_avx(_p2, _p3);
+
+                __m256i _pp = combine4x2_epi32(_pp0, _pp1);
+
+                _pp = _mm256_add_epi8(_pp, _v127);
+
+                _mm256_storeu_si256((__m256i*)pp, _pp);
+
                 pp += 32;
                 p0 += B_hstep * 4;
             }
 #else // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[4] * scale);
-                pp[3] = float2int8(p0[5] * scale);
-                pp[4] = float2int8(p0[8] * scale);
-                pp[5] = float2int8(p0[9] * scale);
-                pp[6] = float2int8(p0[12] * scale);
-                pp[7] = float2int8(p0[13] * scale);
-                pp[8] = float2int8(p0[16] * scale);
-                pp[9] = float2int8(p0[17] * scale);
-                pp[10] = float2int8(p0[20] * scale);
-                pp[11] = float2int8(p0[21] * scale);
-                pp[12] = float2int8(p0[24] * scale);
-                pp[13] = float2int8(p0[25] * scale);
-                pp[14] = float2int8(p0[28] * scale);
-                pp[15] = float2int8(p0[29] * scale);
-
-                pp[16 + 0] = float2int8(p0[2] * scale);
-                pp[16 + 1] = float2int8(p0[3] * scale);
-                pp[16 + 2] = float2int8(p0[6] * scale);
-                pp[16 + 3] = float2int8(p0[7] * scale);
-                pp[16 + 4] = float2int8(p0[10] * scale);
-                pp[16 + 5] = float2int8(p0[11] * scale);
-                pp[16 + 6] = float2int8(p0[14] * scale);
-                pp[16 + 7] = float2int8(p0[15] * scale);
-                pp[16 + 8] = float2int8(p0[18] * scale);
-                pp[16 + 9] = float2int8(p0[19] * scale);
-                pp[16 + 10] = float2int8(p0[22] * scale);
-                pp[16 + 11] = float2int8(p0[23] * scale);
-                pp[16 + 12] = float2int8(p0[26] * scale);
-                pp[16 + 13] = float2int8(p0[27] * scale);
-                pp[16 + 14] = float2int8(p0[30] * scale);
-                pp[16 + 15] = float2int8(p0[31] * scale);
+#if __AVX__
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + 8);
+                __m256 _p2 = _mm256_loadu_ps(p0 + 16);
+                __m256 _p3 = _mm256_loadu_ps(p0 + 24);
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+                _p2 = _mm256_mul_ps(_p2, _scale);
+                _p3 = _mm256_mul_ps(_p3, _scale);
+
+                __m128i _pp0 = float2int8_avx(_p0, _p1);
+                __m128i _pp1 = float2int8_avx(_p2, _p3);
+#else
+                __m128 _p0 = _mm_loadu_ps(p0);
+                __m128 _p1 = _mm_loadu_ps(p0 + 4);
+                __m128 _p2 = _mm_loadu_ps(p0 + 8);
+                __m128 _p3 = _mm_loadu_ps(p0 + 12);
+                __m128 _p4 = _mm_loadu_ps(p0 + 16);
+                __m128 _p5 = _mm_loadu_ps(p0 + 20);
+                __m128 _p6 = _mm_loadu_ps(p0 + 24);
+                __m128 _p7 = _mm_loadu_ps(p0 + 28);
+
+                _p0 = _mm_mul_ps(_p0, _scale);
+                _p1 = _mm_mul_ps(_p1, _scale);
+                _p2 = _mm_mul_ps(_p2, _scale);
+                _p3 = _mm_mul_ps(_p3, _scale);
+                _p4 = _mm_mul_ps(_p4, _scale);
+                _p5 = _mm_mul_ps(_p5, _scale);
+                _p6 = _mm_mul_ps(_p6, _scale);
+                _p7 = _mm_mul_ps(_p7, _scale);
+
+                __m128i _pp0 = float2int8_sse(_p0, _p1, _p2, _p3);
+                __m128i _pp1 = float2int8_sse(_p4, _p5, _p6, _p7);
+#endif
+                __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1);
+                __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1);
+                __m128i _t2 = _mm_unpacklo_epi16(_t0, _t1);
+                __m128i _t3 = _mm_unpackhi_epi16(_t0, _t1);
+                _t0 = _mm_unpacklo_epi16(_t2, _t3);
+                _t1 = _mm_unpackhi_epi16(_t2, _t3);
+
+                _mm_storeu_si128((__m128i*)pp, _t0);
+                _mm_storeu_si128((__m128i*)(pp + 16), _t1);

                 pp += 32;
                 p0 += B_hstep * 4;
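Several of the new blocks funnel two __m128i results into one 256-bit store through combine4x2_epi32. A plausible shape for that helper (hypothetical; the real one lives in ncnn's x86 utility headers):

#include <immintrin.h>

// Hypothetical sketch of combine4x2_epi32: concatenate two 128-bit halves
// into one __m256i.
static inline __m256i combine4x2_epi32_sketch(__m128i a, __m128i b)
{
#if __AVX2__
    return _mm256_inserti128_si256(_mm256_castsi128_si256(a), b, 1);
#else
    return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1); // AVX1 VINSERTF128
#endif
}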
@@ -8721,74 +7818,90 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
 #if __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 3 < max_kk; kk += 4)
             {
-                pp[0] = float2int8(p0[0] * scale) + 127;
-                pp[1] = float2int8(p0[B_hstep] * scale) + 127;
-                pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127;
-                pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127;
-                pp[4] = float2int8(p0[1] * scale) + 127;
-                pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127;
-                pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127;
-                pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127;
-                pp[8] = float2int8(p0[2] * scale) + 127;
-                pp[9] = float2int8(p0[B_hstep + 2] * scale) + 127;
-                pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127;
-                pp[11] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127;
-                pp[12] = float2int8(p0[3] * scale) + 127;
-                pp[13] = float2int8(p0[B_hstep + 3] * scale) + 127;
-                pp[14] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127;
-                pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127;
-                pp[16] = float2int8(p0[4] * scale) + 127;
-                pp[17] = float2int8(p0[B_hstep + 4] * scale) + 127;
-                pp[18] = float2int8(p0[B_hstep * 2 + 4] * scale) + 127;
-                pp[19] = float2int8(p0[B_hstep * 3 + 4] * scale) + 127;
-                pp[20] = float2int8(p0[5] * scale) + 127;
-                pp[21] = float2int8(p0[B_hstep + 5] * scale) + 127;
-                pp[22] = float2int8(p0[B_hstep * 2 + 5] * scale) + 127;
-                pp[23] = float2int8(p0[B_hstep * 3 + 5] * scale) + 127;
-                pp[24] = float2int8(p0[6] * scale) + 127;
-                pp[25] = float2int8(p0[B_hstep + 6] * scale) + 127;
-                pp[26] = float2int8(p0[B_hstep * 2 + 6] * scale) + 127;
-                pp[27] = float2int8(p0[B_hstep * 3 + 6] * scale) + 127;
-                pp[28] = float2int8(p0[7] * scale) + 127;
-                pp[29] = float2int8(p0[B_hstep + 7] * scale) + 127;
-                pp[30] = float2int8(p0[B_hstep * 2 + 7] * scale) + 127;
-                pp[31] = float2int8(p0[B_hstep * 3 + 7] * scale) + 127;
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep);
+                __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 2);
+                __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 3);
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+                _p2 = _mm256_mul_ps(_p2, _scale);
+                _p3 = _mm256_mul_ps(_p3, _scale);
+
+                __m128i _pp0 = float2int8_avx(_p0, _p2);
+                __m128i _pp1 = float2int8_avx(_p1, _p3);
+
+                __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1);
+                __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1);
+                _pp0 = _mm_unpacklo_epi16(_tt0, _tt1);
+                _pp1 = _mm_unpackhi_epi16(_tt0, _tt1);
+
+                __m256i _pp = combine4x2_epi32(_pp0, _pp1);
+
+                _pp = _mm256_add_epi8(_pp, _v127);
+
+                _mm256_storeu_si256((__m256i*)pp, _pp);
+
                 pp += 32;
                 p0 += B_hstep * 4;
             }
 #endif // __AVX512VNNI__ || __AVXVNNI__
             for (; kk + 1 < max_kk; kk += 2)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[B_hstep] * scale);
-                pp[2] = float2int8(p0[1] * scale);
-                pp[3] = float2int8(p0[B_hstep + 1] * scale);
-                pp[4] = float2int8(p0[2] * scale);
-                pp[5] = float2int8(p0[B_hstep + 2] * scale);
-                pp[6] = float2int8(p0[3] * scale);
-                pp[7] = float2int8(p0[B_hstep + 3] * scale);
-                pp[8] = float2int8(p0[4] * scale);
-                pp[9] = float2int8(p0[B_hstep + 4] * scale);
-                pp[10] = float2int8(p0[5] * scale);
-                pp[11] = float2int8(p0[B_hstep + 5] * scale);
-                pp[12] = float2int8(p0[6] * scale);
-                pp[13] = float2int8(p0[B_hstep + 6] * scale);
-                pp[14] = float2int8(p0[7] * scale);
-                pp[15] = float2int8(p0[B_hstep + 7] * scale);
+#if __AVX__
+                __m256 _p0 = _mm256_loadu_ps(p0);
+                __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep);
+
+                _p0 = _mm256_mul_ps(_p0, _scale);
+                _p1 = _mm256_mul_ps(_p1, _scale);
+
+                __m128i _pp = float2int8_avx(_p0, _p1);
+
+                __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
+                _pp = _mm_shuffle_epi8(_pp, _si);
+#else
+                __m128 _p0 = _mm_loadu_ps(p0);
+                __m128 _p1 = _mm_loadu_ps(p0 + 4);
+                __m128 _p2 = _mm_loadu_ps(p0 + B_hstep);
+                __m128 _p3 = _mm_loadu_ps(p0 + B_hstep + 4);
+
+                __m128 _t0 = _mm_unpacklo_ps(_p0, _p2);
+                __m128 _t1 = _mm_unpackhi_ps(_p0, _p2);
+                __m128 _t2 = _mm_unpacklo_ps(_p1, _p3);
+                __m128 _t3 = _mm_unpackhi_ps(_p1, _p3);
+
+                _t0 = _mm_mul_ps(_t0, _scale);
+                _t1 = _mm_mul_ps(_t1, _scale);
+                _t2 = _mm_mul_ps(_t2, _scale);
+                _t3 = _mm_mul_ps(_t3, _scale);
+
+                __m128i _pp = float2int8_sse(_t0, _t1, _t2, _t3);
+#endif
+
+                _mm_storeu_si128((__m128i*)pp, _pp);

                 pp += 16;
                 p0 += B_hstep * 2;
             }
             for (; kk < max_kk; kk++)
             {
-                pp[0] = float2int8(p0[0] * scale);
-                pp[1] = float2int8(p0[1] * scale);
-                pp[2] = float2int8(p0[2] * scale);
-                pp[3] = float2int8(p0[3] * scale);
-                pp[4] = float2int8(p0[4] * scale);
-                pp[5] = float2int8(p0[5] * scale);
-                pp[6] = float2int8(p0[6] * scale);
-                pp[7] = float2int8(p0[7] * scale);
+#if __AVX__
+                __m256 _p = _mm256_loadu_ps(p0);
+
+                _p = _mm256_mul_ps(_p, _scale);
+
+                int64_t v = float2int8_avx(_p);
+#else
+                __m128 _p0 = _mm_loadu_ps(p0);
+                __m128 _p1 = _mm_loadu_ps(p0 + 4);
+
+                _p0 = _mm_mul_ps(_p0, _scale);
+                _p1 = _mm_mul_ps(_p1, _scale);
+
+                int64_t v = float2int8_sse(_p0, _p1);
+#endif
+                *(int64_t*)pp = v;
+
                 pp += 8;
                 p0 += B_hstep;
             }
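In the kk + 1 tail above, the _mm_shuffle_epi8 index 0, 8, 1, 9, ... interleaves the two k rows byte by byte, assuming float2int8_avx returns row 0's eight bytes in the low half and row 1's in the high half. A self-contained check of that pattern (illustrative; requires SSSE3):

#include <tmmintrin.h>
#include <stdio.h>

int main()
{
    // low half: row k0 for columns 0..7; high half: row k1
    __m128i rows = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
                                 10, 11, 12, 13, 14, 15, 16, 17);
    __m128i si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
    __m128i pp = _mm_shuffle_epi8(rows, si);

    signed char out[16];
    _mm_storeu_si128((__m128i*)out, pp);
    for (int i = 0; i < 16; i++)
        printf("%d ", out[i]); // prints 0 10 1 11 2 12 ... column-major k-pairs
    printf("\n");
    return 0;
}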