Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 8, 2023
1 parent 4198cb1 commit 4cc7b2b
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/layer/x86/convolution_3x3_winograd_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max

int jj = 0;
#if __SSE2__
#if defined(__x86_64__) || defined(_M_X64)
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
Expand Down Expand Up @@ -429,6 +430,7 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max
pp += 8;
}
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* p0 = B;
Expand Down Expand Up @@ -611,6 +613,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
const short* pB = BT_tile.row<const short>(b);

int jj = 0;
#if defined(__x86_64__) || defined(_M_X64)
for (; jj + 15 < max_jj; jj += 16)
{
const short* pA = pAT;
Expand Down Expand Up @@ -1094,6 +1097,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_mm512_store_si512((__m512i*)(outptr + 16 * 7), _sum7);
outptr += 16 * 8;
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* pA = pAT;
Expand Down Expand Up @@ -1323,6 +1327,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
const short* pB = BT_tile.row<const short>(b);

int jj = 0;
#if defined(__x86_64__) || defined(_M_X64)
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
Expand Down Expand Up @@ -1406,9 +1411,9 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
__m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
__m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2));

- __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
+ __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
__m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1));
- __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3));
+ __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD);

__m512i _s0 = _mm512_mullo_epi32(_pA00, _pB0);
__m512i _s1 = _mm512_mullo_epi32(_pA00, _pB1);
Expand Down Expand Up @@ -1796,6 +1801,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
outptr += 8 * 8;
#endif // __AVX512F__
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* pA = pAT;
Expand Down Expand Up @@ -2020,6 +2026,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
const short* pB = BT_tile.row<const short>(b);

int jj = 0;
#if defined(__x86_64__) || defined(_M_X64)
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
Expand Down Expand Up @@ -2414,6 +2421,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
outptr += 32;
#endif // __AVX2__
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* pA = pAT;
Expand Down Expand Up @@ -2676,6 +2684,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,

int jj = 0;
#if __SSE2__
#if defined(__x86_64__) || defined(_M_X64)
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
Expand Down Expand Up @@ -2875,6 +2884,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
outptr += 16;
#endif // __AVX2__
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* pA = pAT;
Expand Down Expand Up @@ -3041,6 +3051,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,

int jj = 0;
#if __SSE2__
#if defined(__x86_64__) || defined(_M_X64)
#if __AVX512F__
for (; jj + 15 < max_jj; jj += 16)
{
Expand Down Expand Up @@ -3136,6 +3147,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_mm_storeu_si128((__m128i*)(outptr + 4), _sum1);
outptr += 8;
}
#endif // defined(__x86_64__) || defined(_M_X64)
for (; jj + 3 < max_jj; jj += 4)
{
const short* pA = pAT;
Expand Down

0 comments on commit 4cc7b2b

Please sign in to comment.