diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h
index 67a9b89570e..c5e8b06f259 100644
--- a/src/layer/x86/lstm_int8.h
+++ b/src/layer/x86/lstm_int8.h
@@ -152,6 +152,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             __m512i _w0_shift = _mm512_setzero_si512();
             __m512i _w1_shift = _mm512_setzero_si512();
+#if defined(__x86_64__) || defined(_M_X64)
             __m512i _w2_shift = _mm512_setzero_si512();
             __m512i _w3_shift = _mm512_setzero_si512();
             for (; i + 15 < size; i += 16)
@@ -202,6 +203,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm512_setzero_si512();
             _w1_shift = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i)));
@@ -312,7 +314,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm512_storeu_si512((__m512i*)kptr, _w_shift);
             kptr += 64;
 #else
-
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i)));
@@ -333,7 +335,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_xc_G_3 + i)));
                 kptr += 128;
             }
-
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 kptr[0] = weight_xc_I_0[i];
@@ -465,6 +467,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w_shift = _mm512_setzero_si512();
             _w0_shift = _mm512_setzero_si512();
             _w1_shift = _mm512_setzero_si512();
+#if defined(__x86_64__) || defined(_M_X64)
             _w2_shift = _mm512_setzero_si512();
             _w3_shift = _mm512_setzero_si512();
             for (; i + 15 < num_output; i += 16)
@@ -515,6 +518,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm512_setzero_si512();
             _w1_shift = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i)));
@@ -625,6 +629,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm512_storeu_si512((__m512i*)kptr, _w_shift);
             kptr += 64;
 #else
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i)));
@@ -645,7 +650,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_hc_G_3 + i)));
                 kptr += 128;
             }
-
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 kptr[0] = weight_hc_I_0[i];
@@ -845,6 +850,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             __m256i _w0_shift = _mm256_setzero_si256();
             __m256i _w1_shift = _mm256_setzero_si256();
+#if defined(__x86_64__) || defined(_M_X64)
             __m256i _w2_shift = _mm256_setzero_si256();
             __m256i _w3_shift = _mm256_setzero_si256();
             for (; i + 15 < size; i += 16)
@@ -878,6 +884,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm256_setzero_si256();
             _w1_shift = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i)));
@@ -945,6 +952,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm256_storeu_si256((__m256i*)kptr, _w_shift);
             kptr += 32;
 #else
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i)));
@@ -957,6 +965,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i)));
                 kptr += 64;
             }
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 kptr[0] = weight_xc_I_0[i];
@@ -1033,6 +1042,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _v127 = _mm256_set1_epi8(127);
             _w0_shift = _mm256_setzero_si256();
             _w1_shift = _mm256_setzero_si256();
+#if defined(__x86_64__) || defined(_M_X64)
             _w2_shift = _mm256_setzero_si256();
             _w3_shift = _mm256_setzero_si256();
             for (; i + 15 < num_output; i += 16)
@@ -1066,6 +1076,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm256_setzero_si256();
             _w1_shift = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i)));
@@ -1133,6 +1144,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm256_storeu_si256((__m256i*)kptr, _w_shift);
             kptr += 32;
 #else
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i)));
@@ -1145,6 +1157,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i)));
                 kptr += 64;
             }
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 kptr[0] = weight_hc_I_0[i];
@@ -1270,6 +1283,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             __m128i _v127 = _mm_set1_epi8(127);
             __m128i _w0_shift = _mm_setzero_si128();
             __m128i _w1_shift = _mm_setzero_si128();
+#if defined(__x86_64__) || defined(_M_X64)
             __m128i _w2_shift = _mm_setzero_si128();
             __m128i _w3_shift = _mm_setzero_si128();
             for (; i + 15 < size; i += 16)
@@ -1300,6 +1314,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm_setzero_si128();
             _w1_shift = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i)));
@@ -1347,6 +1362,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm_storeu_si128((__m128i*)kptr, _w_shift);
             kptr += 16;
 #else
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i)));
@@ -1355,6 +1371,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_G + i)));
                 kptr += 32;
             }
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 kptr[0] = weight_xc_I[i];
@@ -1404,6 +1421,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w_shift = _mm_setzero_si128();
             _w0_shift = _mm_setzero_si128();
             _w1_shift = _mm_setzero_si128();
+#if defined(__x86_64__) || defined(_M_X64)
             _w2_shift = _mm_setzero_si128();
             _w3_shift = _mm_setzero_si128();
             for (; i + 15 < num_output; i += 16)
@@ -1434,6 +1452,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _w0_shift = _mm_setzero_si128();
             _w1_shift = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i)));
@@ -1481,6 +1500,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             _mm_storeu_si128((__m128i*)kptr, _w_shift);
             kptr += 16;
 #else
+#if defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i)));
@@ -1489,6 +1509,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
                 _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_G + i)));
                 kptr += 32;
             }
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 kptr[0] = weight_hc_I[i];
@@ -1652,13 +1673,14 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m512i _lstm_IFOGx0 = _mm512_setzero_si512();
             __m512i _sum0 = _mm512_setzero_si512();
             __m512i _sum1 = _mm512_setzero_si512();
-            __m512i _sum2 = _mm512_setzero_si512();
-            __m512i _sum3 = _mm512_setzero_si512();
             int i = 0;
 #if __AVX512VNNI__
             __m128i _v127q = _mm_set1_epi8(127);
             __m512i _v127 = _mm512_set1_epi8(127);
+#if defined(__x86_64__) || defined(_M_X64)
+            __m512i _sum2 = _mm512_setzero_si512();
+            __m512i _sum3 = _mm512_setzero_si512();
             for (; i + 15 < size; i += 16)
             {
                 __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i));
@@ -1695,6 +1717,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm512_setzero_si512();
             _sum1 = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 __m128i _xi = _mm_loadl_epi64((const __m128i*)(x + i));
@@ -1733,6 +1756,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 64;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            __m512i _sum2 = _mm512_setzero_si512();
+            __m512i _sum3 = _mm512_setzero_si512();
             for (; i + 7 < size; i += 8)
             {
                 __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i)));
@@ -1776,6 +1802,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm512_setzero_si512();
             _sum1 = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i)));
@@ -1836,10 +1863,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m512i _lstm_IFOGh0 = _mm512_setzero_si512();
             _sum0 = _mm512_setzero_si512();
             _sum1 = _mm512_setzero_si512();
-            _sum2 = _mm512_setzero_si512();
-            _sum3 = _mm512_setzero_si512();
             i = 0;
 #if __AVX512VNNI__
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm512_setzero_si512();
+            _sum3 = _mm512_setzero_si512();
             for (; i + 15 < num_output; i += 16)
             {
                 __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i));
@@ -1876,6 +1904,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm512_setzero_si512();
             _sum1 = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 __m128i _h_cont = _mm_loadl_epi64((const __m128i*)(hs + i));
@@ -1914,6 +1943,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 64;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm512_setzero_si512();
+            _sum3 = _mm512_setzero_si512();
             for (; i + 7 < num_output; i += 8)
             {
                 __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i)));
@@ -1957,6 +1989,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm512_setzero_si512();
             _sum1 = _mm512_setzero_si512();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i)));
@@ -2059,12 +2092,13 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m256i _lstm_IFOGx0 = _mm256_setzero_si256();
             __m256i _sum0 = _mm256_setzero_si256();
             __m256i _sum1 = _mm256_setzero_si256();
-            __m256i _sum2 = _mm256_setzero_si256();
-            __m256i _sum3 = _mm256_setzero_si256();
             int i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
             __m128i _v127q = _mm_set1_epi8(127);
             __m256i _v127 = _mm256_set1_epi8(127);
+#if defined(__x86_64__) || defined(_M_X64)
+            __m256i _sum2 = _mm256_setzero_si256();
+            __m256i _sum3 = _mm256_setzero_si256();
             for (; i + 15 < size; i += 16)
             {
                 __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i));
@@ -2092,6 +2126,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm256_setzero_si256();
             _sum1 = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i)));
@@ -2125,6 +2160,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 32;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            __m256i _sum2 = _mm256_setzero_si256();
+            __m256i _sum3 = _mm256_setzero_si256();
             for (; i + 7 < size; i += 8)
             {
                 __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i)));
@@ -2159,6 +2197,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm256_setzero_si256();
             _sum1 = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i)));
@@ -2216,10 +2255,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m256i _lstm_IFOGh0 = _mm256_setzero_si256();
             _sum0 = _mm256_setzero_si256();
             _sum1 = _mm256_setzero_si256();
-            _sum2 = _mm256_setzero_si256();
-            _sum3 = _mm256_setzero_si256();
             i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm256_setzero_si256();
+            _sum3 = _mm256_setzero_si256();
             for (; i + 15 < num_output; i += 16)
             {
                 __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i));
@@ -2247,6 +2287,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm256_setzero_si256();
             _sum1 = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i)));
@@ -2280,6 +2321,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 32;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm256_setzero_si256();
+            _sum3 = _mm256_setzero_si256();
             for (; i + 7 < num_output; i += 8)
             {
                 __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i)));
@@ -2314,6 +2358,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm256_setzero_si256();
             _sum1 = _mm256_setzero_si256();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i)));
@@ -2413,11 +2458,12 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m128i _lstm_IFOGx0 = _mm_setzero_si128();
             __m128i _sum0 = _mm_setzero_si128();
             __m128i _sum1 = _mm_setzero_si128();
-            __m128i _sum2 = _mm_setzero_si128();
-            __m128i _sum3 = _mm_setzero_si128();
             int i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
             __m128i _v127 = _mm_set1_epi8(127);
+#if defined(__x86_64__) || defined(_M_X64)
+            __m128i _sum2 = _mm_setzero_si128();
+            __m128i _sum3 = _mm_setzero_si128();
             for (; i + 15 < size; i += 16)
             {
                 __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i));
@@ -2444,6 +2490,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm_setzero_si128();
             _sum1 = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < size; i += 8)
             {
                 __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i)));
@@ -2477,6 +2524,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 16;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            __m128i _sum2 = _mm_setzero_si128();
+            __m128i _sum3 = _mm_setzero_si128();
             for (; i + 7 < size; i += 8)
             {
                 __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i)));
@@ -2527,6 +2577,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm_setzero_si128();
             _sum1 = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < size; i += 4)
             {
                 __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i)));
@@ -2617,10 +2668,11 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m128i _lstm_IFOGh0 = _mm_setzero_si128();
             _sum0 = _mm_setzero_si128();
             _sum1 = _mm_setzero_si128();
-            _sum2 = _mm_setzero_si128();
-            _sum3 = _mm_setzero_si128();
             i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm_setzero_si128();
+            _sum3 = _mm_setzero_si128();
             for (; i + 15 < num_output; i += 16)
             {
                 __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i));
@@ -2647,6 +2699,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm_setzero_si128();
             _sum1 = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 7 < num_output; i += 8)
             {
                 __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i)));
@@ -2680,6 +2733,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 kptr += 16;
             }
 #else
+#if defined(__x86_64__) || defined(_M_X64)
+            _sum2 = _mm_setzero_si128();
+            _sum3 = _mm_setzero_si128();
             for (; i + 7 < num_output; i += 8)
             {
                 __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i)));
@@ -2730,6 +2786,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _sum0 = _mm_setzero_si128();
             _sum1 = _mm_setzero_si128();
+#endif // defined(__x86_64__) || defined(_M_X64)
             for (; i + 3 < num_output; i += 4)
             {
                 __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i)));
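
Note (reviewer illustration, not part of the patch): every hunk above fences the extra _sum2/_sum3 and _w2_shift/_w3_shift accumulators, and the 16-element-per-iteration loops that feed them, behind #if defined(__x86_64__) || defined(_M_X64), so the deeper unrolling only compiles on 64-bit x86, which exposes at least 16 vector registers instead of the 8 available to 32-bit builds. The sketch below shows the same guard pattern applied to a hypothetical int16 dot-product kernel; every identifier in it is illustrative and none of it comes from lstm_int8.h.

#include <emmintrin.h> // SSE2

// Hypothetical kernel: sum of a[i] * b[i] over n int16 values.
// Assumes n is a multiple of 16 and ignores 32-bit accumulator overflow for brevity.
static int dot_s16(const short* a, const short* b, int n)
{
    __m128i _sum0 = _mm_setzero_si128();
    __m128i _sum1 = _mm_setzero_si128();
    int i = 0;
#if defined(__x86_64__) || defined(_M_X64)
    // 64-bit targets: two extra accumulators allow a 32-element unroll without spilling.
    __m128i _sum2 = _mm_setzero_si128();
    __m128i _sum3 = _mm_setzero_si128();
    for (; i + 31 < n; i += 32)
    {
        _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i)), _mm_loadu_si128((const __m128i*)(b + i))));
        _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i + 8)), _mm_loadu_si128((const __m128i*)(b + i + 8))));
        _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i + 16)), _mm_loadu_si128((const __m128i*)(b + i + 16))));
        _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i + 24)), _mm_loadu_si128((const __m128i*)(b + i + 24))));
    }
    // Fold the extra accumulators back before the shared tail loop.
    _sum0 = _mm_add_epi32(_sum0, _sum2);
    _sum1 = _mm_add_epi32(_sum1, _sum3);
#endif // defined(__x86_64__) || defined(_M_X64)
    // 32-bit targets (8 xmm registers) and the 64-bit tail: 16-element unroll with two accumulators.
    for (; i + 15 < n; i += 16)
    {
        _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i)), _mm_loadu_si128((const __m128i*)(b + i))));
        _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_loadu_si128((const __m128i*)(a + i + 8)), _mm_loadu_si128((const __m128i*)(b + i + 8))));
    }
    // Horizontal reduction of the four 32-bit lanes.
    _sum0 = _mm_add_epi32(_sum0, _sum1);
    _sum0 = _mm_add_epi32(_sum0, _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(1, 0, 3, 2)));
    _sum0 = _mm_add_epi32(_sum0, _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(_sum0);
}

The guard checks both macros because GCC/Clang define __x86_64__ and MSVC defines _M_X64 for 64-bit x86 builds; on 32-bit builds neither is defined, so only the two-accumulator path is compiled.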