From 9ce793041300da16655a7b89627b67353803ee32 Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 13 Apr 2024 15:47:03 +0800 Subject: [PATCH] x86 optimization for convolution tiled gemm (#5426) --- src/layer/x86/convolution_im2col_gemm.h | 4681 ++++++++++++++++++ src/layer/x86/convolution_im2col_gemm_int8.h | 1253 +---- src/layer/x86/convolution_x86.cpp | 514 +- src/layer/x86/convolution_x86.h | 2 - tests/test_convolution_2.cpp | 3 +- 5 files changed, 4786 insertions(+), 1667 deletions(-) create mode 100644 src/layer/x86/convolution_im2col_gemm.h diff --git a/src/layer/x86/convolution_im2col_gemm.h b/src/layer/x86/convolution_im2col_gemm.h new file mode 100644 index 00000000000..e683a2bb951 --- /dev/null +++ b/src/layer/x86/convolution_im2col_gemm.h @@ -0,0 +1,4681 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void convolution_im2col_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + // A = (pa, maxk, inch/pa), outch + const int A_hstep = A.w; + + float* pp = AT; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; + const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; + const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; + const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; + const float* p8 = (const float*)A + (i + ii + 8) * A_hstep + k; + const float* p9 = (const float*)A + (i + ii + 9) * A_hstep + k; + const float* pa = (const float*)A + (i + ii + 10) * A_hstep + k; + const float* pb = (const float*)A + (i + ii + 11) * A_hstep + k; + const float* pc = (const float*)A + (i + ii + 12) * A_hstep + k; + const float* pd = (const float*)A + (i + ii + 13) * A_hstep + k; + const float* pe = (const float*)A + (i + ii + 14) * A_hstep + k; + const float* pf = (const float*)A + (i + ii + 15) * A_hstep + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _r0 = _mm512_loadu_ps(p0); + __m512 _r1 = _mm512_loadu_ps(p1); + __m512 _r2 = _mm512_loadu_ps(p2); + __m512 _r3 = _mm512_loadu_ps(p3); + __m512 _r4 = _mm512_loadu_ps(p4); + __m512 _r5 = _mm512_loadu_ps(p5); + __m512 _r6 = _mm512_loadu_ps(p6); + __m512 _r7 = _mm512_loadu_ps(p7); + __m512 _r8 = _mm512_loadu_ps(p8); + __m512 _r9 = _mm512_loadu_ps(p9); + __m512 _ra = _mm512_loadu_ps(pa); + __m512 _rb = _mm512_loadu_ps(pb); + __m512 _rc = _mm512_loadu_ps(pc); + __m512 _rd = _mm512_loadu_ps(pd); + __m512 _re = _mm512_loadu_ps(pe); + __m512 _rf = _mm512_loadu_ps(pf); + transpose16x16_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + _mm512_store_ps(pp + 16 * 8, _r8); + _mm512_store_ps(pp + 16 * 9, _r9); + _mm512_store_ps(pp + 16 * 10, _ra); + _mm512_store_ps(pp + 16 * 11, _rb); + _mm512_store_ps(pp + 16 * 12, _rc); + _mm512_store_ps(pp + 16 * 13, _rd); + _mm512_store_ps(pp + 16 * 14, _re); + _mm512_store_ps(pp + 16 * 15, _rf); + pp += 256; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + p4 += 16; + p5 += 16; + p6 += 16; + p7 += 16; + p8 += 16; + p9 += 16; + pa += 16; + pb += 16; + pc += 16; + pd += 16; + pe += 16; + pf += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp[12] = pc[0]; + pp[13] = pd[0]; + pp[14] = pe[0]; + pp[15] = pf[0]; + pp += 16; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + pc++; + pd++; + pe++; + pf++; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; + const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; + const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; + const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _r0 = _mm256_loadu_ps(p0); + __m256 _r1 = _mm256_loadu_ps(p1); + __m256 _r2 = _mm256_loadu_ps(p2); + __m256 _r3 = _mm256_loadu_ps(p3); + __m256 _r4 = _mm256_loadu_ps(p4); + __m256 _r5 = _mm256_loadu_ps(p5); + __m256 _r6 = _mm256_loadu_ps(p6); + __m256 _r7 = _mm256_loadu_ps(p7); + transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + + int kk = 0; +#if __AVX__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _r0 = _mm256_loadu_ps(p0); + __m256 _r1 = _mm256_loadu_ps(p1); + __m256 _r2 = _mm256_loadu_ps(p2); + __m256 _r3 = _mm256_loadu_ps(p3); + transpose8x4_ps(_r0, _r1, _r2, _r3); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8, _r1); + _mm256_store_ps(pp + 16, _r2); + _mm256_store_ps(pp + 24, _r3); + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } +#endif // __AVX__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _r0 = _mm_loadu_ps(p0); + __m128 _r1 = _mm_loadu_ps(p1); + __m128 _r2 = _mm_loadu_ps(p2); + __m128 _r3 = _mm_loadu_ps(p3); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4, _r1); + _mm_store_ps(pp + 8, _r2); + _mm_store_ps(pp + 12, _r3); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + + int kk = 0; +#if __SSE2__ +#if __AVX__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _r0 = _mm256_loadu_ps(p0); + __m256 _r1 = _mm256_loadu_ps(p1); + transpose8x2_ps(_r0, _r1); + _mm256_storeu_ps(pp, _r0); + _mm256_storeu_ps(pp + 8, _r1); + pp += 16; + p0 += 8; + p1 += 8; + } +#endif // __AVX__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _r0 = _mm_loadu_ps(p0); + __m128 _r1 = _mm_loadu_ps(p1); + __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); + __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); + _mm_store_ps(pp, _tmp0); + _mm_store_ps(pp + 4, _tmp1); + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + int kk = 0; +#if __SSE2__ +#if __AVX__ + for (; kk + 7 < max_kk; kk += 8) + { + _mm256_storeu_ps(pp, _mm256_loadu_ps(p0)); + pp += 8; + p0 += 8; + } +#endif // __AVX__ + for (; kk + 3 < max_kk; kk += 4) + { + _mm_storeu_ps(pp, _mm_loadu_ps(p0)); + pp += 4; + p0 += 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void convolution_gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) +{ + // NCNN_LOGE("convolution_gemm_transB_packed_tile %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + + const int out_elempack = top_blob.elempack; + const int out_hstep = (int)top_blob.cstep; + + const float* pAT = AT_tile; + const float* pBT = BT_tile; + const float* pC = CT_tile; + + float* outptr = topT_tile; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + pC = (const float*)CT_tile + i + ii; + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + const float* pA = pAT; + + __m512 _sum0; + __m512 _sum1; + __m512 _sum2; + __m512 _sum3; + __m512 _sum4; + __m512 _sum5; + __m512 _sum6; + __m512 _sum7; + __m512 _sum8; + __m512 _sum9; + __m512 _suma; + __m512 _sumb; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + else + { + _sum0 = _mm512_setzero_ps(); + _sum1 = _mm512_setzero_ps(); + _sum2 = _mm512_setzero_ps(); + _sum3 = _mm512_setzero_ps(); + _sum4 = _mm512_setzero_ps(); + _sum5 = _mm512_setzero_ps(); + _sum6 = _mm512_setzero_ps(); + _sum7 = _mm512_setzero_ps(); + _sum8 = _mm512_setzero_ps(); + _sum9 = _mm512_setzero_ps(); + _suma = _mm512_setzero_ps(); + _sumb = _mm512_setzero_ps(); + } + } + else + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16 * 1); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); + _sum4 = _mm512_load_ps(outptr + 16 * 4); + _sum5 = _mm512_load_ps(outptr + 16 * 5); + _sum6 = _mm512_load_ps(outptr + 16 * 6); + _sum7 = _mm512_load_ps(outptr + 16 * 7); + _sum8 = _mm512_load_ps(outptr + 16 * 8); + _sum9 = _mm512_load_ps(outptr + 16 * 9); + _suma = _mm512_load_ps(outptr + 16 * 10); + _sumb = _mm512_load_ps(outptr + 16 * 11); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m512 _pA = _mm512_load_ps(pA); + + _sum0 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[3]), _sum3); + _sum4 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[4]), _sum4); + _sum5 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[5]), _sum5); + _sum6 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[6]), _sum6); + _sum7 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[7]), _sum7); + _sum8 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[8]), _sum8); + _sum9 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[9]), _sum9); + _suma = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[10]), _suma); + _sumb = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[11]), _sumb); + + pA += 16; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 16) + { + _mm512_store_ps(outptr0, _sum0); + _mm512_store_ps(outptr0 + 16 * 1, _sum1); + _mm512_store_ps(outptr0 + 16 * 2, _sum2); + _mm512_store_ps(outptr0 + 16 * 3, _sum3); + _mm512_store_ps(outptr0 + 16 * 4, _sum4); + _mm512_store_ps(outptr0 + 16 * 5, _sum5); + _mm512_store_ps(outptr0 + 16 * 6, _sum6); + _mm512_store_ps(outptr0 + 16 * 7, _sum7); + _mm512_store_ps(outptr0 + 16 * 8, _sum8); + _mm512_store_ps(outptr0 + 16 * 9, _sum9); + _mm512_store_ps(outptr0 + 16 * 10, _suma); + _mm512_store_ps(outptr0 + 16 * 11, _sumb); + outptr0 += 192; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpa = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + _mm512_storeu_ps(outptr0 + 16 * 2, _tmp2); + _mm512_storeu_ps(outptr0 + 16 * 3, _tmp3); + _mm512_storeu_ps(outptr0 + 16 * 4, _tmp4); + _mm512_storeu_ps(outptr0 + 16 * 5, _tmp5); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp7); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 2, _tmp8); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 3, _tmp9); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 4, _tmpa); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 5, _tmpb); + + outptr0 += 96; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpa = _mm512_shuffle_f32x4(_sum8, _sum9, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_suma, _sumb, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum8 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _sum9 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _suma = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _sumb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + 16, _sum4); + _mm512_storeu_ps(outptr0 + 32, _sum8); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 16, _sum5); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 32, _sum9); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _sum6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 32, _suma); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 16, _sum7); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 32, _sumb); + + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose16x12_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb); + + _mm256_storeu_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); + _mm_storeu_ps(outptr0 + 8, _mm512_extractf32x4_ps(_sum0, 2)); + _mm_storeu_ps(outptr0 + out_hstep * 1, _mm512_extractf32x4_ps(_sum0, 3)); + _mm256_storeu_ps(outptr0 + out_hstep * 1 + 4, _mm512_extractf32x8_ps(_sum1, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 2, _mm512_extractf32x8_ps(_sum1, 1)); + _mm_storeu_ps(outptr0 + out_hstep * 2 + 8, _mm512_extractf32x4_ps(_sum2, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 3, _mm512_extractf32x4_ps(_sum2, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 3 + 4, _mm512_extractf32x8_ps(_sum2, 1)); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _mm512_extractf32x8_ps(_sum3, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 4 + 8, _mm512_extractf32x4_ps(_sum3, 2)); + _mm_storeu_ps(outptr0 + out_hstep * 5, _mm512_extractf32x4_ps(_sum3, 3)); + _mm256_storeu_ps(outptr0 + out_hstep * 5 + 4, _mm512_extractf32x8_ps(_sum4, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 6, _mm512_extractf32x8_ps(_sum4, 1)); + _mm_storeu_ps(outptr0 + out_hstep * 6 + 8, _mm512_extractf32x4_ps(_sum5, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 7, _mm512_extractf32x4_ps(_sum5, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 7 + 4, _mm512_extractf32x8_ps(_sum5, 1)); + + _mm256_storeu_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum6, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 8 + 8, _mm512_extractf32x4_ps(_sum6, 2)); + _mm_storeu_ps(outptr0 + out_hstep * 9, _mm512_extractf32x4_ps(_sum6, 3)); + _mm256_storeu_ps(outptr0 + out_hstep * 9 + 4, _mm512_extractf32x8_ps(_sum7, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 10, _mm512_extractf32x8_ps(_sum7, 1)); + _mm_storeu_ps(outptr0 + out_hstep * 10 + 8, _mm512_extractf32x4_ps(_sum8, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 11, _mm512_extractf32x4_ps(_sum8, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 11 + 4, _mm512_extractf32x8_ps(_sum8, 1)); + + _mm256_storeu_ps(outptr0 + out_hstep * 12, _mm512_extractf32x8_ps(_sum9, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 12 + 8, _mm512_extractf32x4_ps(_sum9, 2)); + _mm_storeu_ps(outptr0 + out_hstep * 13, _mm512_extractf32x4_ps(_sum9, 3)); + _mm256_storeu_ps(outptr0 + out_hstep * 13 + 4, _mm512_extractf32x8_ps(_suma, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 14, _mm512_extractf32x8_ps(_suma, 1)); + _mm_storeu_ps(outptr0 + out_hstep * 14 + 8, _mm512_extractf32x4_ps(_sumb, 0)); + _mm_storeu_ps(outptr0 + out_hstep * 15, _mm512_extractf32x4_ps(_sumb, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 15 + 4, _mm512_extractf32x8_ps(_sumb, 1)); + + outptr0 += 12; + } + } + else + { + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16 * 1, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + _mm512_store_ps(outptr + 16 * 4, _sum4); + _mm512_store_ps(outptr + 16 * 5, _sum5); + _mm512_store_ps(outptr + 16 * 6, _sum6); + _mm512_store_ps(outptr + 16 * 7, _sum7); + _mm512_store_ps(outptr + 16 * 8, _sum8); + _mm512_store_ps(outptr + 16 * 9, _sum9); + _mm512_store_ps(outptr + 16 * 10, _suma); + _mm512_store_ps(outptr + 16 * 11, _sumb); + } + + outptr += 192; + } + for (; jj + 7 < max_jj; jj += 8) + { + const float* pA = pAT; + + __m512 _sum0; + __m512 _sum1; + __m512 _sum2; + __m512 _sum3; + __m512 _sum4; + __m512 _sum5; + __m512 _sum6; + __m512 _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = _mm512_setzero_ps(); + _sum1 = _mm512_setzero_ps(); + _sum2 = _mm512_setzero_ps(); + _sum3 = _mm512_setzero_ps(); + _sum4 = _mm512_setzero_ps(); + _sum5 = _mm512_setzero_ps(); + _sum6 = _mm512_setzero_ps(); + _sum7 = _mm512_setzero_ps(); + } + } + else + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16 * 1); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); + _sum4 = _mm512_load_ps(outptr + 16 * 4); + _sum5 = _mm512_load_ps(outptr + 16 * 5); + _sum6 = _mm512_load_ps(outptr + 16 * 6); + _sum7 = _mm512_load_ps(outptr + 16 * 7); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m512 _pA = _mm512_load_ps(pA); + + _sum0 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[3]), _sum3); + _sum4 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[4]), _sum4); + _sum5 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[5]), _sum5); + _sum6 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[6]), _sum6); + _sum7 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[7]), _sum7); + + pA += 16; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 16) + { + _mm512_store_ps(outptr0, _sum0); + _mm512_store_ps(outptr0 + 16 * 1, _sum1); + _mm512_store_ps(outptr0 + 16 * 2, _sum2); + _mm512_store_ps(outptr0 + 16 * 3, _sum3); + _mm512_store_ps(outptr0 + 16 * 4, _sum4); + _mm512_store_ps(outptr0 + 16 * 5, _sum5); + _mm512_store_ps(outptr0 + 16 * 6, _sum6); + _mm512_store_ps(outptr0 + 16 * 7, _sum7); + outptr0 += 128; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + _mm512_storeu_ps(outptr0 + 16 * 2, _tmp2); + _mm512_storeu_ps(outptr0 + 16 * 3, _tmp3); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp4); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp5); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 2, _tmp6); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16 * 3, _tmp7); + + outptr0 += 64; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_sum4, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_sum6, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + 16, _sum4); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 4 + 16, _sum5); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _sum6); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); + _mm512_storeu_ps(outptr0 + out_hstep * 12 + 16, _sum7); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose16x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + _mm256_storeu_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 1, _mm512_extractf32x8_ps(_sum0, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 2, _mm512_extractf32x8_ps(_sum1, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 3, _mm512_extractf32x8_ps(_sum1, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 4, _mm512_extractf32x8_ps(_sum2, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 5, _mm512_extractf32x8_ps(_sum2, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 6, _mm512_extractf32x8_ps(_sum3, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 7, _mm512_extractf32x8_ps(_sum3, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum4, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 9, _mm512_extractf32x8_ps(_sum4, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 10, _mm512_extractf32x8_ps(_sum5, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 11, _mm512_extractf32x8_ps(_sum5, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 12, _mm512_extractf32x8_ps(_sum6, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 13, _mm512_extractf32x8_ps(_sum6, 1)); + _mm256_storeu_ps(outptr0 + out_hstep * 14, _mm512_extractf32x8_ps(_sum7, 0)); + _mm256_storeu_ps(outptr0 + out_hstep * 15, _mm512_extractf32x8_ps(_sum7, 1)); + + outptr0 += 8; + } + } + else + { + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16 * 1, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + _mm512_store_ps(outptr + 16 * 4, _sum4); + _mm512_store_ps(outptr + 16 * 5, _sum5); + _mm512_store_ps(outptr + 16 * 6, _sum6); + _mm512_store_ps(outptr + 16 * 7, _sum7); + } + + outptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const float* pA = pAT; + + __m512 _sum0; + __m512 _sum1; + __m512 _sum2; + __m512 _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = _mm512_setzero_ps(); + _sum1 = _mm512_setzero_ps(); + _sum2 = _mm512_setzero_ps(); + _sum3 = _mm512_setzero_ps(); + } + } + else + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16 * 1); + _sum2 = _mm512_load_ps(outptr + 16 * 2); + _sum3 = _mm512_load_ps(outptr + 16 * 3); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m512 _pA = _mm512_load_ps(pA); + + _sum0 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[1]), _sum1); + _sum2 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[2]), _sum2); + _sum3 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[3]), _sum3); + + pA += 16; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 16) + { + _mm512_store_ps(outptr0, _sum0); + _mm512_store_ps(outptr0 + 16 * 1, _sum1); + _mm512_store_ps(outptr0 + 16 * 2, _sum2); + _mm512_store_ps(outptr0 + 16 * 3, _sum3); + outptr0 += 64; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + 16, _tmp1); + + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp2); + _mm512_storeu_ps(outptr0 + out_hstep * 8 + 16, _tmp3); + + outptr0 += 32; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + + _sum0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(outptr0, _sum0); + _mm512_storeu_ps(outptr0 + out_hstep * 4, _sum1); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _sum2); + _mm512_storeu_ps(outptr0 + out_hstep * 12, _sum3); + + outptr0 += 16; + } + if (out_elempack == 1) + { + __m128 _sum0_0 = _mm512_extractf32x4_ps(_sum0, 0); + __m128 _sum1_0 = _mm512_extractf32x4_ps(_sum1, 0); + __m128 _sum2_0 = _mm512_extractf32x4_ps(_sum2, 0); + __m128 _sum3_0 = _mm512_extractf32x4_ps(_sum3, 0); + __m128 _sum0_1 = _mm512_extractf32x4_ps(_sum0, 1); + __m128 _sum1_1 = _mm512_extractf32x4_ps(_sum1, 1); + __m128 _sum2_1 = _mm512_extractf32x4_ps(_sum2, 1); + __m128 _sum3_1 = _mm512_extractf32x4_ps(_sum3, 1); + __m128 _sum0_2 = _mm512_extractf32x4_ps(_sum0, 2); + __m128 _sum1_2 = _mm512_extractf32x4_ps(_sum1, 2); + __m128 _sum2_2 = _mm512_extractf32x4_ps(_sum2, 2); + __m128 _sum3_2 = _mm512_extractf32x4_ps(_sum3, 2); + __m128 _sum0_3 = _mm512_extractf32x4_ps(_sum0, 3); + __m128 _sum1_3 = _mm512_extractf32x4_ps(_sum1, 3); + __m128 _sum2_3 = _mm512_extractf32x4_ps(_sum2, 3); + __m128 _sum3_3 = _mm512_extractf32x4_ps(_sum3, 3); + + _MM_TRANSPOSE4_PS(_sum0_0, _sum1_0, _sum2_0, _sum3_0); + _MM_TRANSPOSE4_PS(_sum0_1, _sum1_1, _sum2_1, _sum3_1); + _MM_TRANSPOSE4_PS(_sum0_2, _sum1_2, _sum2_2, _sum3_2); + _MM_TRANSPOSE4_PS(_sum0_3, _sum1_3, _sum2_3, _sum3_3); + + _mm_storeu_ps(outptr0, _sum0_0); + _mm_storeu_ps(outptr0 + out_hstep * 1, _sum1_0); + _mm_storeu_ps(outptr0 + out_hstep * 2, _sum2_0); + _mm_storeu_ps(outptr0 + out_hstep * 3, _sum3_0); + _mm_storeu_ps(outptr0 + out_hstep * 4, _sum0_1); + _mm_storeu_ps(outptr0 + out_hstep * 5, _sum1_1); + _mm_storeu_ps(outptr0 + out_hstep * 6, _sum2_1); + _mm_storeu_ps(outptr0 + out_hstep * 7, _sum3_1); + _mm_storeu_ps(outptr0 + out_hstep * 8, _sum0_2); + _mm_storeu_ps(outptr0 + out_hstep * 9, _sum1_2); + _mm_storeu_ps(outptr0 + out_hstep * 10, _sum2_2); + _mm_storeu_ps(outptr0 + out_hstep * 11, _sum3_2); + _mm_storeu_ps(outptr0 + out_hstep * 12, _sum0_3); + _mm_storeu_ps(outptr0 + out_hstep * 13, _sum1_3); + _mm_storeu_ps(outptr0 + out_hstep * 14, _sum2_3); + _mm_storeu_ps(outptr0 + out_hstep * 15, _sum3_3); + + outptr0 += 4; + } + } + else + { + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16 * 1, _sum1); + _mm512_store_ps(outptr + 16 * 2, _sum2); + _mm512_store_ps(outptr + 16 * 3, _sum3); + } + + outptr += 64; + } + for (; jj + 1 < max_jj; jj += 2) + { + const float* pA = pAT; + + __m512 _sum0; + __m512 _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm512_loadu_ps(pC); + _sum1 = _sum0; + } + else + { + _sum0 = _mm512_setzero_ps(); + _sum1 = _mm512_setzero_ps(); + } + } + else + { + _sum0 = _mm512_load_ps(outptr); + _sum1 = _mm512_load_ps(outptr + 16); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m512 _pA = _mm512_load_ps(pA); + + _sum0 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[0]), _sum0); + _sum1 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[1]), _sum1); + + pA += 16; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 16) + { + _mm512_store_ps(outptr0, _sum0); + _mm512_store_ps(outptr0 + 16, _sum1); + outptr0 += 32; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(outptr0, _tmp0); + _mm512_storeu_ps(outptr0 + out_hstep * 8, _tmp1); + + outptr0 += 16; + } + if (out_elempack == 4) + { + _mm_store_ps(outptr0, _mm512_extractf32x4_ps(_sum0, 0)); + _mm_store_ps(outptr0 + 4, _mm512_extractf32x4_ps(_sum1, 0)); + + _mm_store_ps(outptr0 + out_hstep * 4, _mm512_extractf32x4_ps(_sum0, 1)); + _mm_store_ps(outptr0 + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_sum1, 1)); + + _mm_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x4_ps(_sum0, 2)); + _mm_store_ps(outptr0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_sum1, 2)); + + _mm_store_ps(outptr0 + out_hstep * 12, _mm512_extractf32x4_ps(_sum0, 3)); + _mm_store_ps(outptr0 + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_sum1, 3)); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[16]; + float sum1[16]; + _mm512_storeu_ps(sum0, _sum0); + _mm512_storeu_ps(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0[out_hstep * 8] = sum0[8]; + outptr0[out_hstep * 9] = sum0[9]; + outptr0[out_hstep * 10] = sum0[10]; + outptr0[out_hstep * 11] = sum0[11]; + outptr0[out_hstep * 12] = sum0[12]; + outptr0[out_hstep * 13] = sum0[13]; + outptr0[out_hstep * 14] = sum0[14]; + outptr0[out_hstep * 15] = sum0[15]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0[out_hstep * 8 + 1] = sum1[8]; + outptr0[out_hstep * 9 + 1] = sum1[9]; + outptr0[out_hstep * 10 + 1] = sum1[10]; + outptr0[out_hstep * 11 + 1] = sum1[11]; + outptr0[out_hstep * 12 + 1] = sum1[12]; + outptr0[out_hstep * 13 + 1] = sum1[13]; + outptr0[out_hstep * 14 + 1] = sum1[14]; + outptr0[out_hstep * 15 + 1] = sum1[15]; + outptr0 += 2; + } + } + else + { + _mm512_store_ps(outptr, _sum0); + _mm512_store_ps(outptr + 16, _sum1); + } + + outptr += 32; + } + for (; jj < max_jj; jj += 1) + { + const float* pA = pAT; + + __m512 _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm512_loadu_ps(pC); + } + else + { + _sum0 = _mm512_setzero_ps(); + } + } + else + { + _sum0 = _mm512_load_ps(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m512 _pA = _mm512_load_ps(pA); + + _sum0 = _mm512_fmadd_ps(_pA, _mm512_set1_ps(pB[0]), _sum0); + + pA += 16; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 16) + { + _mm512_store_ps(outptr0, _sum0); + outptr0 += 16; + } + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _mm512_extractf32x8_ps(_sum0, 0)); + _mm256_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x8_ps(_sum0, 1)); + outptr0 += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(outptr0, _mm512_extractf32x4_ps(_sum0, 0)); + _mm_store_ps(outptr0 + out_hstep * 4, _mm512_extractf32x4_ps(_sum0, 1)); + _mm_store_ps(outptr0 + out_hstep * 8, _mm512_extractf32x4_ps(_sum0, 2)); + _mm_store_ps(outptr0 + out_hstep * 12, _mm512_extractf32x4_ps(_sum0, 3)); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[16]; + _mm512_storeu_ps(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0[out_hstep * 8] = sum0[8]; + outptr0[out_hstep * 9] = sum0[9]; + outptr0[out_hstep * 10] = sum0[10]; + outptr0[out_hstep * 11] = sum0[11]; + outptr0[out_hstep * 12] = sum0[12]; + outptr0[out_hstep * 13] = sum0[13]; + outptr0[out_hstep * 14] = sum0[14]; + outptr0[out_hstep * 15] = sum0[15]; + outptr0++; + } + } + else + { + _mm512_store_ps(outptr, _sum0); + } + + outptr += 16; + } + + pAT += max_kk * 16; + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + pC = (const float*)CT_tile + i + ii; + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + const float* pA = pAT; + + __m256 _sum0; + __m256 _sum1; + __m256 _sum2; + __m256 _sum3; + __m256 _sum4; + __m256 _sum5; + __m256 _sum6; + __m256 _sum7; + __m256 _sum8; + __m256 _sum9; + __m256 _suma; + __m256 _sumb; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm256_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + else + { + _sum0 = _mm256_setzero_ps(); + _sum1 = _mm256_setzero_ps(); + _sum2 = _mm256_setzero_ps(); + _sum3 = _mm256_setzero_ps(); + _sum4 = _mm256_setzero_ps(); + _sum5 = _mm256_setzero_ps(); + _sum6 = _mm256_setzero_ps(); + _sum7 = _mm256_setzero_ps(); + _sum8 = _mm256_setzero_ps(); + _sum9 = _mm256_setzero_ps(); + _suma = _mm256_setzero_ps(); + _sumb = _mm256_setzero_ps(); + } + } + else + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8 * 1); + _sum2 = _mm256_load_ps(outptr + 8 * 2); + _sum3 = _mm256_load_ps(outptr + 8 * 3); + _sum4 = _mm256_load_ps(outptr + 8 * 4); + _sum5 = _mm256_load_ps(outptr + 8 * 5); + _sum6 = _mm256_load_ps(outptr + 8 * 6); + _sum7 = _mm256_load_ps(outptr + 8 * 7); + _sum8 = _mm256_load_ps(outptr + 8 * 8); + _sum9 = _mm256_load_ps(outptr + 8 * 9); + _suma = _mm256_load_ps(outptr + 8 * 10); + _sumb = _mm256_load_ps(outptr + 8 * 11); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m256 _pA = _mm256_load_ps(pA); + + _sum0 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[3]), _sum3); + _sum4 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[4]), _sum4); + _sum5 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[5]), _sum5); + _sum6 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[6]), _sum6); + _sum7 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[7]), _sum7); + _sum8 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[8]), _sum8); + _sum9 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[9]), _sum9); + _suma = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[10]), _suma); + _sumb = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[11]), _sumb); + + pA += 8; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _sum0); + _mm256_store_ps(outptr0 + 8 * 1, _sum1); + _mm256_store_ps(outptr0 + 8 * 2, _sum2); + _mm256_store_ps(outptr0 + 8 * 3, _sum3); + _mm256_store_ps(outptr0 + 8 * 4, _sum4); + _mm256_store_ps(outptr0 + 8 * 5, _sum5); + _mm256_store_ps(outptr0 + 8 * 6, _sum6); + _mm256_store_ps(outptr0 + 8 * 7, _sum7); + _mm256_store_ps(outptr0 + 8 * 8, _sum8); + _mm256_store_ps(outptr0 + 8 * 9, _sum9); + _mm256_store_ps(outptr0 + 8 * 10, _suma); + _mm256_store_ps(outptr0 + 8 * 11, _sumb); + outptr0 += 96; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_sum8, _sum9, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp5 = _mm256_permute2f128_ps(_suma, _sumb, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp6 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp8 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp9 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmpa = _mm256_permute2f128_ps(_sum8, _sum9, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmpb = _mm256_permute2f128_ps(_suma, _sumb, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + _mm256_storeu_ps(outptr0 + 8 * 2, _tmp2); + _mm256_storeu_ps(outptr0 + 8 * 3, _tmp3); + _mm256_storeu_ps(outptr0 + 8 * 4, _tmp4); + _mm256_storeu_ps(outptr0 + 8 * 5, _tmp5); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp6); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp7); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 2, _tmp8); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 3, _tmp9); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 4, _tmpa); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 5, _tmpb); + + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose8x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + _mm256_storeu_ps(outptr0, _sum0); + _mm256_storeu_ps(outptr0 + out_hstep * 1, _sum1); + _mm256_storeu_ps(outptr0 + out_hstep * 2, _sum2); + _mm256_storeu_ps(outptr0 + out_hstep * 3, _sum3); + _mm256_storeu_ps(outptr0 + out_hstep * 4, _sum4); + _mm256_storeu_ps(outptr0 + out_hstep * 5, _sum5); + _mm256_storeu_ps(outptr0 + out_hstep * 6, _sum6); + _mm256_storeu_ps(outptr0 + out_hstep * 7, _sum7); + + __m128 _sum8_0 = _mm256_extractf128_ps(_sum8, 0); + __m128 _sum9_0 = _mm256_extractf128_ps(_sum9, 0); + __m128 _suma_0 = _mm256_extractf128_ps(_suma, 0); + __m128 _sumb_0 = _mm256_extractf128_ps(_sumb, 0); + __m128 _sum8_1 = _mm256_extractf128_ps(_sum8, 1); + __m128 _sum9_1 = _mm256_extractf128_ps(_sum9, 1); + __m128 _suma_1 = _mm256_extractf128_ps(_suma, 1); + __m128 _sumb_1 = _mm256_extractf128_ps(_sumb, 1); + + _MM_TRANSPOSE4_PS(_sum8_0, _sum9_0, _suma_0, _sumb_0); + _MM_TRANSPOSE4_PS(_sum8_1, _sum9_1, _suma_1, _sumb_1); + + _mm_storeu_ps(outptr0 + 8, _sum8_0); + _mm_storeu_ps(outptr0 + out_hstep * 1 + 8, _sum9_0); + _mm_storeu_ps(outptr0 + out_hstep * 2 + 8, _suma_0); + _mm_storeu_ps(outptr0 + out_hstep * 3 + 8, _sumb_0); + _mm_storeu_ps(outptr0 + out_hstep * 4 + 8, _sum8_1); + _mm_storeu_ps(outptr0 + out_hstep * 5 + 8, _sum9_1); + _mm_storeu_ps(outptr0 + out_hstep * 6 + 8, _suma_1); + _mm_storeu_ps(outptr0 + out_hstep * 7 + 8, _sumb_1); + + outptr0 += 12; + } + } + else + { + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8 * 1, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + _mm256_store_ps(outptr + 8 * 4, _sum4); + _mm256_store_ps(outptr + 8 * 5, _sum5); + _mm256_store_ps(outptr + 8 * 6, _sum6); + _mm256_store_ps(outptr + 8 * 7, _sum7); + _mm256_store_ps(outptr + 8 * 8, _sum8); + _mm256_store_ps(outptr + 8 * 9, _sum9); + _mm256_store_ps(outptr + 8 * 10, _suma); + _mm256_store_ps(outptr + 8 * 11, _sumb); + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + const float* pA = pAT; + + __m256 _sum0; + __m256 _sum1; + __m256 _sum2; + __m256 _sum3; + __m256 _sum4; + __m256 _sum5; + __m256 _sum6; + __m256 _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm256_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = _mm256_setzero_ps(); + _sum1 = _mm256_setzero_ps(); + _sum2 = _mm256_setzero_ps(); + _sum3 = _mm256_setzero_ps(); + _sum4 = _mm256_setzero_ps(); + _sum5 = _mm256_setzero_ps(); + _sum6 = _mm256_setzero_ps(); + _sum7 = _mm256_setzero_ps(); + } + } + else + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8 * 1); + _sum2 = _mm256_load_ps(outptr + 8 * 2); + _sum3 = _mm256_load_ps(outptr + 8 * 3); + _sum4 = _mm256_load_ps(outptr + 8 * 4); + _sum5 = _mm256_load_ps(outptr + 8 * 5); + _sum6 = _mm256_load_ps(outptr + 8 * 6); + _sum7 = _mm256_load_ps(outptr + 8 * 7); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m256 _pA = _mm256_load_ps(pA); + + _sum0 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[3]), _sum3); + _sum4 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[4]), _sum4); + _sum5 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[5]), _sum5); + _sum6 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[6]), _sum6); + _sum7 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[7]), _sum7); + + pA += 8; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _sum0); + _mm256_store_ps(outptr0 + 8 * 1, _sum1); + _mm256_store_ps(outptr0 + 8 * 2, _sum2); + _mm256_store_ps(outptr0 + 8 * 3, _sum3); + _mm256_store_ps(outptr0 + 8 * 4, _sum4); + _mm256_store_ps(outptr0 + 8 * 5, _sum5); + _mm256_store_ps(outptr0 + 8 * 6, _sum6); + _mm256_store_ps(outptr0 + 8 * 7, _sum7); + outptr0 += 64; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp5 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp6 = _mm256_permute2f128_ps(_sum4, _sum5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_sum6, _sum7, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + _mm256_storeu_ps(outptr0 + 8 * 2, _tmp2); + _mm256_storeu_ps(outptr0 + 8 * 3, _tmp3); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp4); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp5); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 2, _tmp6); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8 * 3, _tmp7); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + _mm256_storeu_ps(outptr0, _sum0); + _mm256_storeu_ps(outptr0 + out_hstep * 1, _sum1); + _mm256_storeu_ps(outptr0 + out_hstep * 2, _sum2); + _mm256_storeu_ps(outptr0 + out_hstep * 3, _sum3); + _mm256_storeu_ps(outptr0 + out_hstep * 4, _sum4); + _mm256_storeu_ps(outptr0 + out_hstep * 5, _sum5); + _mm256_storeu_ps(outptr0 + out_hstep * 6, _sum6); + _mm256_storeu_ps(outptr0 + out_hstep * 7, _sum7); + + outptr0 += 8; + } + } + else + { + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8 * 1, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + _mm256_store_ps(outptr + 8 * 4, _sum4); + _mm256_store_ps(outptr + 8 * 5, _sum5); + _mm256_store_ps(outptr + 8 * 6, _sum6); + _mm256_store_ps(outptr + 8 * 7, _sum7); + } + + outptr += 64; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const float* pA = pAT; + + __m256 _sum0; + __m256 _sum1; + __m256 _sum2; + __m256 _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm256_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = _mm256_setzero_ps(); + _sum1 = _mm256_setzero_ps(); + _sum2 = _mm256_setzero_ps(); + _sum3 = _mm256_setzero_ps(); + } + } + else + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8 * 1); + _sum2 = _mm256_load_ps(outptr + 8 * 2); + _sum3 = _mm256_load_ps(outptr + 8 * 3); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m256 _pA = _mm256_load_ps(pA); + + _sum0 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[1]), _sum1); + _sum2 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[2]), _sum2); + _sum3 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[3]), _sum3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _sum0); + _mm256_store_ps(outptr0 + 8 * 1, _sum1); + _mm256_store_ps(outptr0 + 8 * 2, _sum2); + _mm256_store_ps(outptr0 + 8 * 3, _sum3); + outptr0 += 32; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + 8, _tmp1); + + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp2); + _mm256_storeu_ps(outptr0 + out_hstep * 4 + 8, _tmp3); + + outptr0 += 16; + } + if (out_elempack == 1) + { + __m128 _sum0_0 = _mm256_extractf128_ps(_sum0, 0); + __m128 _sum1_0 = _mm256_extractf128_ps(_sum1, 0); + __m128 _sum2_0 = _mm256_extractf128_ps(_sum2, 0); + __m128 _sum3_0 = _mm256_extractf128_ps(_sum3, 0); + __m128 _sum0_1 = _mm256_extractf128_ps(_sum0, 1); + __m128 _sum1_1 = _mm256_extractf128_ps(_sum1, 1); + __m128 _sum2_1 = _mm256_extractf128_ps(_sum2, 1); + __m128 _sum3_1 = _mm256_extractf128_ps(_sum3, 1); + + _MM_TRANSPOSE4_PS(_sum0_0, _sum1_0, _sum2_0, _sum3_0); + _MM_TRANSPOSE4_PS(_sum0_1, _sum1_1, _sum2_1, _sum3_1); + + _mm_storeu_ps(outptr0, _sum0_0); + _mm_storeu_ps(outptr0 + out_hstep * 1, _sum1_0); + _mm_storeu_ps(outptr0 + out_hstep * 2, _sum2_0); + _mm_storeu_ps(outptr0 + out_hstep * 3, _sum3_0); + _mm_storeu_ps(outptr0 + out_hstep * 4, _sum0_1); + _mm_storeu_ps(outptr0 + out_hstep * 5, _sum1_1); + _mm_storeu_ps(outptr0 + out_hstep * 6, _sum2_1); + _mm_storeu_ps(outptr0 + out_hstep * 7, _sum3_1); + + outptr0 += 4; + } + } + else + { + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8 * 1, _sum1); + _mm256_store_ps(outptr + 8 * 2, _sum2); + _mm256_store_ps(outptr + 8 * 3, _sum3); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const float* pA = pAT; + + __m256 _sum0; + __m256 _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm256_loadu_ps(pC); + _sum1 = _sum0; + } + else + { + _sum0 = _mm256_setzero_ps(); + _sum1 = _mm256_setzero_ps(); + } + } + else + { + _sum0 = _mm256_load_ps(outptr); + _sum1 = _mm256_load_ps(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m256 _pA = _mm256_load_ps(pA); + + _sum0 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[0]), _sum0); + _sum1 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[1]), _sum1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _sum0); + _mm256_store_ps(outptr0 + 8, _sum1); + outptr0 += 16; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(outptr0, _tmp0); + _mm256_storeu_ps(outptr0 + out_hstep * 4, _tmp1); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[8]; + float sum1[8]; + _mm256_storeu_ps(sum0, _sum0); + _mm256_storeu_ps(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0 += 2; + } + } + else + { + _mm256_store_ps(outptr, _sum0); + _mm256_store_ps(outptr + 8, _sum1); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const float* pA = pAT; + + __m256 _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm256_loadu_ps(pC); + } + else + { + _sum0 = _mm256_setzero_ps(); + } + } + else + { + _sum0 = _mm256_load_ps(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m256 _pA = _mm256_load_ps(pA); + + _sum0 = _mm256_comp_fmadd_ps(_pA, _mm256_set1_ps(pB[0]), _sum0); + + pA += 8; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 8) + { + _mm256_store_ps(outptr0, _sum0); + outptr0 += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(outptr0, _mm256_extractf128_ps(_sum0, 0)); + _mm_store_ps(outptr0 + out_hstep * 4, _mm256_extractf128_ps(_sum0, 1)); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[8]; + _mm256_storeu_ps(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0++; + } + } + else + { + _mm256_store_ps(outptr, _sum0); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + pC = (const float*)CT_tile + i + ii; + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + const float* pA = pAT; + + __m128 _sum0; + __m128 _sum1; + __m128 _sum2; + __m128 _sum3; + __m128 _sum4; + __m128 _sum5; + __m128 _sum6; + __m128 _sum7; + __m128 _sum8; + __m128 _sum9; + __m128 _suma; + __m128 _sumb; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + _sum2 = _mm_setzero_ps(); + _sum3 = _mm_setzero_ps(); + _sum4 = _mm_setzero_ps(); + _sum5 = _mm_setzero_ps(); + _sum6 = _mm_setzero_ps(); + _sum7 = _mm_setzero_ps(); + _sum8 = _mm_setzero_ps(); + _sum9 = _mm_setzero_ps(); + _suma = _mm_setzero_ps(); + _sumb = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4 * 1); + _sum2 = _mm_load_ps(outptr + 4 * 2); + _sum3 = _mm_load_ps(outptr + 4 * 3); + _sum4 = _mm_load_ps(outptr + 4 * 4); + _sum5 = _mm_load_ps(outptr + 4 * 5); + _sum6 = _mm_load_ps(outptr + 4 * 6); + _sum7 = _mm_load_ps(outptr + 4 * 7); + _sum8 = _mm_load_ps(outptr + 4 * 8); + _sum9 = _mm_load_ps(outptr + 4 * 9); + _suma = _mm_load_ps(outptr + 4 * 10); + _sumb = _mm_load_ps(outptr + 4 * 11); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pA = _mm_loadu_ps(pA); + + _sum0 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[3]), _sum3); + _sum4 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[4]), _sum4); + _sum5 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[5]), _sum5); + _sum6 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[6]), _sum6); + _sum7 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[7]), _sum7); + _sum8 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[8]), _sum8); + _sum9 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[9]), _sum9); + _suma = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[10]), _suma); + _sumb = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[11]), _sumb); + + pA += 4; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + _mm_storeu_ps(outptr0 + 4 * 2, _sum2); + _mm_storeu_ps(outptr0 + 4 * 3, _sum3); + _mm_storeu_ps(outptr0 + 4 * 4, _sum4); + _mm_storeu_ps(outptr0 + 4 * 5, _sum5); + _mm_storeu_ps(outptr0 + 4 * 6, _sum6); + _mm_storeu_ps(outptr0 + 4 * 7, _sum7); + _mm_storeu_ps(outptr0 + 4 * 8, _sum8); + _mm_storeu_ps(outptr0 + 4 * 9, _sum9); + _mm_storeu_ps(outptr0 + 4 * 10, _suma); + _mm_storeu_ps(outptr0 + 4 * 11, _sumb); + outptr0 += 48; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_sum0, _sum1, _sum2, _sum3); + _MM_TRANSPOSE4_PS(_sum4, _sum5, _sum6, _sum7); + _MM_TRANSPOSE4_PS(_sum8, _sum9, _suma, _sumb); + + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + out_hstep * 1, _sum1); + _mm_storeu_ps(outptr0 + out_hstep * 2, _sum2); + _mm_storeu_ps(outptr0 + out_hstep * 3, _sum3); + _mm_storeu_ps(outptr0 + 4, _sum4); + _mm_storeu_ps(outptr0 + out_hstep * 1 + 4, _sum5); + _mm_storeu_ps(outptr0 + out_hstep * 2 + 4, _sum6); + _mm_storeu_ps(outptr0 + out_hstep * 3 + 4, _sum7); + _mm_storeu_ps(outptr0 + 8, _sum8); + _mm_storeu_ps(outptr0 + out_hstep * 1 + 8, _sum9); + _mm_storeu_ps(outptr0 + out_hstep * 2 + 8, _suma); + _mm_storeu_ps(outptr0 + out_hstep * 3 + 8, _sumb); + outptr0 += 12; + } + } + else + { + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + _mm_store_ps(outptr + 4 * 4, _sum4); + _mm_store_ps(outptr + 4 * 5, _sum5); + _mm_store_ps(outptr + 4 * 6, _sum6); + _mm_store_ps(outptr + 4 * 7, _sum7); + _mm_store_ps(outptr + 4 * 8, _sum8); + _mm_store_ps(outptr + 4 * 9, _sum9); + _mm_store_ps(outptr + 4 * 10, _suma); + _mm_store_ps(outptr + 4 * 11, _sumb); + } + + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + const float* pA = pAT; + + __m128 _sum0; + __m128 _sum1; + __m128 _sum2; + __m128 _sum3; + __m128 _sum4; + __m128 _sum5; + __m128 _sum6; + __m128 _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + _sum2 = _mm_setzero_ps(); + _sum3 = _mm_setzero_ps(); + _sum4 = _mm_setzero_ps(); + _sum5 = _mm_setzero_ps(); + _sum6 = _mm_setzero_ps(); + _sum7 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4 * 1); + _sum2 = _mm_load_ps(outptr + 4 * 2); + _sum3 = _mm_load_ps(outptr + 4 * 3); + _sum4 = _mm_load_ps(outptr + 4 * 4); + _sum5 = _mm_load_ps(outptr + 4 * 5); + _sum6 = _mm_load_ps(outptr + 4 * 6); + _sum7 = _mm_load_ps(outptr + 4 * 7); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pA = _mm_loadu_ps(pA); + + _sum0 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[3]), _sum3); + _sum4 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[4]), _sum4); + _sum5 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[5]), _sum5); + _sum6 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[6]), _sum6); + _sum7 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[7]), _sum7); + + pA += 4; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + _mm_storeu_ps(outptr0 + 4 * 2, _sum2); + _mm_storeu_ps(outptr0 + 4 * 3, _sum3); + _mm_storeu_ps(outptr0 + 4 * 4, _sum4); + _mm_storeu_ps(outptr0 + 4 * 5, _sum5); + _mm_storeu_ps(outptr0 + 4 * 6, _sum6); + _mm_storeu_ps(outptr0 + 4 * 7, _sum7); + outptr0 += 32; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_sum0, _sum1, _sum2, _sum3); + _MM_TRANSPOSE4_PS(_sum4, _sum5, _sum6, _sum7); + + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + out_hstep * 1, _sum1); + _mm_storeu_ps(outptr0 + out_hstep * 2, _sum2); + _mm_storeu_ps(outptr0 + out_hstep * 3, _sum3); + _mm_storeu_ps(outptr0 + 4, _sum4); + _mm_storeu_ps(outptr0 + out_hstep * 1 + 4, _sum5); + _mm_storeu_ps(outptr0 + out_hstep * 2 + 4, _sum6); + _mm_storeu_ps(outptr0 + out_hstep * 3 + 4, _sum7); + outptr0 += 8; + } + } + else + { + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + _mm_store_ps(outptr + 4 * 4, _sum4); + _mm_store_ps(outptr + 4 * 5, _sum5); + _mm_store_ps(outptr + 4 * 6, _sum6); + _mm_store_ps(outptr + 4 * 7, _sum7); + } + + outptr += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const float* pA = pAT; + + __m128 _sum0; + __m128 _sum1; + __m128 _sum2; + __m128 _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_loadu_ps(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + _sum2 = _mm_setzero_ps(); + _sum3 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4 * 1); + _sum2 = _mm_load_ps(outptr + 4 * 2); + _sum3 = _mm_load_ps(outptr + 4 * 3); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pA = _mm_loadu_ps(pA); + + _sum0 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[1]), _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[2]), _sum2); + _sum3 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[3]), _sum3); + + pA += 4; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + _mm_storeu_ps(outptr0 + 4 * 2, _sum2); + _mm_storeu_ps(outptr0 + 4 * 3, _sum3); + outptr0 += 16; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_sum0, _sum1, _sum2, _sum3); + + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + out_hstep * 1, _sum1); + _mm_storeu_ps(outptr0 + out_hstep * 2, _sum2); + _mm_storeu_ps(outptr0 + out_hstep * 3, _sum3); + outptr0 += 4; + } + } + else + { + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + _mm_store_ps(outptr + 4 * 2, _sum2); + _mm_store_ps(outptr + 4 * 3, _sum3); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const float* pA = pAT; + + __m128 _sum0; + __m128 _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_loadu_ps(pC); + _sum1 = _sum0; + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_load_ps(outptr); + _sum1 = _mm_load_ps(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pA = _mm_loadu_ps(pA); + + _sum0 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[0]), _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[1]), _sum1); + + pA += 4; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + _mm_storeu_ps(sum0, _sum0); + _mm_storeu_ps(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0 += 2; + } + } + else + { + _mm_store_ps(outptr, _sum0); + _mm_store_ps(outptr + 4, _sum1); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + const float* pA = pAT; + + __m128 _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_loadu_ps(pC); + } + else + { + _sum0 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_load_ps(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pA = _mm_loadu_ps(pA); + + _sum0 = _mm_comp_fmadd_ps(_pA, _mm_set1_ps(pB[0]), _sum0); + + pA += 4; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + _mm_storeu_ps(outptr0, _sum0); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[4]; + _mm_storeu_ps(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0++; + } + } + else + { + _mm_store_ps(outptr, _sum0); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + pC = (const float*)CT_tile + i + ii; + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + __m128 _sum00; + __m128 _sum01; + __m128 _sum02; + __m128 _sum10; + __m128 _sum11; + __m128 _sum12; + + if (k == 0) + { + if (pC) + { + _sum00 = _mm_set1_ps(pC[0]); + _sum01 = _mm_set1_ps(pC[0]); + _sum02 = _mm_set1_ps(pC[0]); + _sum10 = _mm_set1_ps(pC[1]); + _sum11 = _mm_set1_ps(pC[1]); + _sum12 = _mm_set1_ps(pC[1]); + } + else + { + _sum00 = _mm_setzero_ps(); + _sum01 = _mm_setzero_ps(); + _sum02 = _mm_setzero_ps(); + _sum10 = _mm_setzero_ps(); + _sum11 = _mm_setzero_ps(); + _sum12 = _mm_setzero_ps(); + } + } + else + { + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + __m128 _tmp2 = _mm_loadu_ps(outptr + 8); + __m128 _tmp3 = _mm_loadu_ps(outptr + 12); + __m128 _tmp4 = _mm_loadu_ps(outptr + 16); + __m128 _tmp5 = _mm_loadu_ps(outptr + 20); + _sum00 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum01 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum02 = _mm_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _sum10 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum11 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum12 = _mm_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB0 = _mm_loadu_ps(pB); + __m128 _pB1 = _mm_loadu_ps(pB + 4); + __m128 _pB2 = _mm_loadu_ps(pB + 8); + + __m128 _pA0 = _mm_set1_ps(pA[0]); + _sum00 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum00); + _sum01 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum01); + _sum02 = _mm_comp_fmadd_ps(_pA0, _pB2, _sum02); + __m128 _pA1 = _mm_set1_ps(pA[1]); + _sum10 = _mm_comp_fmadd_ps(_pA1, _pB0, _sum10); + _sum11 = _mm_comp_fmadd_ps(_pA1, _pB1, _sum11); + _sum12 = _mm_comp_fmadd_ps(_pA1, _pB2, _sum12); + + pA += 2; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum00); + _mm_storeu_ps(outptr0 + 4, _sum01); + _mm_storeu_ps(outptr0 + 8, _sum02); + _mm_storeu_ps(outptr0 + out_hstep, _sum10); + _mm_storeu_ps(outptr0 + out_hstep + 4, _sum11); + _mm_storeu_ps(outptr0 + out_hstep + 8, _sum12); + outptr0 += 12; + } + } + else + { + __m128 _tmp0 = _mm_unpacklo_ps(_sum00, _sum10); + __m128 _tmp1 = _mm_unpackhi_ps(_sum00, _sum10); + __m128 _tmp2 = _mm_unpacklo_ps(_sum01, _sum11); + __m128 _tmp3 = _mm_unpackhi_ps(_sum01, _sum11); + __m128 _tmp4 = _mm_unpacklo_ps(_sum02, _sum12); + __m128 _tmp5 = _mm_unpackhi_ps(_sum02, _sum12); + _mm_store_ps(outptr, _tmp0); + _mm_store_ps(outptr + 4, _tmp1); + _mm_store_ps(outptr + 8, _tmp2); + _mm_store_ps(outptr + 12, _tmp3); + _mm_store_ps(outptr + 16, _tmp4); + _mm_store_ps(outptr + 20, _tmp5); + } + + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _sum00; + __m128 _sum01; + __m128 _sum10; + __m128 _sum11; + + if (k == 0) + { + if (pC) + { + _sum00 = _mm_set1_ps(pC[0]); + _sum01 = _mm_set1_ps(pC[0]); + _sum10 = _mm_set1_ps(pC[1]); + _sum11 = _mm_set1_ps(pC[1]); + } + else + { + _sum00 = _mm_setzero_ps(); + _sum01 = _mm_setzero_ps(); + _sum10 = _mm_setzero_ps(); + _sum11 = _mm_setzero_ps(); + } + } + else + { + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + __m128 _tmp2 = _mm_loadu_ps(outptr + 8); + __m128 _tmp3 = _mm_loadu_ps(outptr + 12); + _sum00 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum01 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum10 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum11 = _mm_shuffle_ps(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB0 = _mm_loadu_ps(pB); + __m128 _pB1 = _mm_loadu_ps(pB + 4); + + __m128 _pA0 = _mm_set1_ps(pA[0]); + _sum00 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum00); + _sum01 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum01); + __m128 _pA1 = _mm_set1_ps(pA[1]); + _sum10 = _mm_comp_fmadd_ps(_pA1, _pB0, _sum10); + _sum11 = _mm_comp_fmadd_ps(_pA1, _pB1, _sum11); + + pA += 2; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum00); + _mm_storeu_ps(outptr0 + 4, _sum01); + _mm_storeu_ps(outptr0 + out_hstep, _sum10); + _mm_storeu_ps(outptr0 + out_hstep + 4, _sum11); + outptr0 += 8; + } + } + else + { + __m128 _tmp0 = _mm_unpacklo_ps(_sum00, _sum10); + __m128 _tmp1 = _mm_unpackhi_ps(_sum00, _sum10); + __m128 _tmp2 = _mm_unpacklo_ps(_sum01, _sum11); + __m128 _tmp3 = _mm_unpackhi_ps(_sum01, _sum11); + _mm_store_ps(outptr, _tmp0); + _mm_store_ps(outptr + 4, _tmp1); + _mm_store_ps(outptr + 8, _tmp2); + _mm_store_ps(outptr + 12, _tmp3); + } + + outptr += 16; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _sum0; + __m128 _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_set1_ps(pC[0]); + _sum1 = _mm_set1_ps(pC[1]); + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + } + } + else + { + __m128 _tmp0 = _mm_loadu_ps(outptr); + __m128 _tmp1 = _mm_loadu_ps(outptr + 4); + _sum0 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB = _mm_loadu_ps(pB); + + _sum0 = _mm_comp_fmadd_ps(_mm_set1_ps(pA[0]), _pB, _sum0); + _sum1 = _mm_comp_fmadd_ps(_mm_set1_ps(pA[1]), _pB, _sum1); + + pA += 2; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + out_hstep, _sum1); + outptr0 += 4; + } + } + else + { + __m128 _tmp0 = _mm_unpacklo_ps(_sum0, _sum1); + __m128 _tmp1 = _mm_unpackhi_ps(_sum0, _sum1); + _mm_storeu_ps(outptr, _tmp0); + _mm_storeu_ps(outptr + 4, _tmp1); + } + + outptr += 8; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float sum00; + float sum01; + float sum10; + float sum11; + + if (k == 0) + { + if (pC) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + else + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + float sum0; + float sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + pC = (const float*)CT_tile + i + ii; + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + __m128 _sum0; + __m128 _sum1; + __m128 _sum2; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_set1_ps(pC[0]); + _sum1 = _mm_set1_ps(pC[0]); + _sum2 = _mm_set1_ps(pC[0]); + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + _sum2 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_loadu_ps(outptr); + _sum1 = _mm_loadu_ps(outptr + 4); + _sum2 = _mm_loadu_ps(outptr + 8); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB0 = _mm_loadu_ps(pB); + __m128 _pB1 = _mm_loadu_ps(pB + 4); + __m128 _pB2 = _mm_loadu_ps(pB + 8); + + __m128 _pA0 = _mm_set1_ps(pA[0]); + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_pA0, _pB2, _sum2); + + pA += 1; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + _mm_storeu_ps(outptr0 + 8, _sum2); + outptr0 += 12; + } + } + else + { + _mm_storeu_ps(outptr, _sum0); + _mm_storeu_ps(outptr + 4, _sum1); + _mm_storeu_ps(outptr + 8, _sum2); + } + + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _sum0; + __m128 _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = _mm_set1_ps(pC[0]); + _sum1 = _mm_set1_ps(pC[0]); + } + else + { + _sum0 = _mm_setzero_ps(); + _sum1 = _mm_setzero_ps(); + } + } + else + { + _sum0 = _mm_loadu_ps(outptr); + _sum1 = _mm_loadu_ps(outptr + 4); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB0 = _mm_loadu_ps(pB); + __m128 _pB1 = _mm_loadu_ps(pB + 4); + + __m128 _pA0 = _mm_set1_ps(pA[0]); + _sum0 = _mm_comp_fmadd_ps(_pA0, _pB0, _sum0); + _sum1 = _mm_comp_fmadd_ps(_pA0, _pB1, _sum1); + + pA += 1; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum0); + _mm_storeu_ps(outptr0 + 4, _sum1); + outptr0 += 8; + } + } + else + { + _mm_storeu_ps(outptr, _sum0); + _mm_storeu_ps(outptr + 4, _sum1); + } + + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _sum; + + if (k == 0) + { + if (pC) + { + _sum = _mm_set1_ps(pC[0]); + } + else + { + _sum = _mm_setzero_ps(); + } + } + else + { + _sum = _mm_loadu_ps(outptr); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __m128 _pB = _mm_loadu_ps(pB); + + _sum = _mm_comp_fmadd_ps(_mm_set1_ps(pA[0]), _pB, _sum); + + pA += 1; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + _mm_storeu_ps(outptr0, _sum); + outptr0 += 4; + } + } + else + { + _mm_storeu_ps(outptr, _sum); + } + + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float sum0; + float sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + + pA += 1; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + float sum; + + if (k == 0) + { + if (pC) + { + sum = pC[0]; + } + else + { + sum = 0.f; + } + } + else + { + sum = outptr[0]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum; + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void convolution_im2col_gemm_get_optimal_tile_mnk(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const int l2_cache_size_fp32 = (int)(get_cpu_level2_cache_size() / sizeof(float)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // solve K + { + // try not to split K +#if __AVX512F__ + int tile_size = (l2_cache_size_fp32 - 64) / 16; +#elif __AVX__ + int tile_size = (l2_cache_size_fp32 - 32) / 8; +#elif __SSE2__ + int tile_size = (l2_cache_size_fp32 - 16) / 8; +#else + int tile_size = (l2_cache_size_fp32 - 2) / 3; +#endif + +#if __AVX512F__ + TILE_K = std::max(16, tile_size / 16 * 16); +#elif __AVX__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + // solve M + { +#if __AVX512F__ + int nn_M = (M + 63) / 64; +#elif __AVX__ + int nn_M = (M + 31) / 32; +#elif __SSE2__ + int nn_M = (M + 15) / 16; +#else + int nn_M = (M + 7) / 8; +#endif + +#if __AVX512F__ + TILE_M = std::max(16, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::max(8, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::max(2, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + } + + { + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + } + + if (N > 0) + { + int tile_size; + if (TILE_K >= K) + { + tile_size = (l2_cache_size_fp32 - TILE_M * TILE_K) / TILE_K; + } + else + { + tile_size = (l2_cache_size_fp32 - TILE_M * TILE_K) / (TILE_M + TILE_K); + } + +#if __AVX512F__ + TILE_N = std::max(4, tile_size / 4 * 4); +#elif __AVX__ + TILE_N = std::max(4, tile_size / 4 * 4); +#elif __SSE2__ + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __AVX512F__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __AVX__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + +#if __AVX512F__ + TILE_N = std::max(4, TILE_N); +#elif __AVX__ + TILE_N = std::max(4, TILE_N); +#elif __SSE2__ + TILE_N = std::max(4, TILE_N); +#else + TILE_N = std::max(1, TILE_N); +#endif + } +} + +static void convolution_im2col_input_tile_conv1x1s1d1(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +{ + const int elempack = bottom_blob.elempack; + + float* pp = B; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + const float* p0 = (const float*)bottom_blob.channel(k / 16) + (j + jj) * 16; + + int kk = 0; + for (; kk < max_kk / 16; kk++) + { + __m512 _r0 = _mm512_load_ps(p0); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); + __m512 _r4 = _mm512_load_ps(p0 + 16 * 4); + __m512 _r5 = _mm512_load_ps(p0 + 16 * 5); + __m512 _r6 = _mm512_load_ps(p0 + 16 * 6); + __m512 _r7 = _mm512_load_ps(p0 + 16 * 7); + __m512 _r8 = _mm512_load_ps(p0 + 16 * 8); + __m512 _r9 = _mm512_load_ps(p0 + 16 * 9); + __m512 _ra = _mm512_load_ps(p0 + 16 * 10); + __m512 _rb = _mm512_load_ps(p0 + 16 * 11); + transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + _mm512_store_ps(pp + 16 * 8, _r8); + _mm512_store_ps(pp + 16 * 9, _r9); + _mm512_store_ps(pp + 16 * 10, _ra); + _mm512_store_ps(pp + 16 * 11, _rb); + pp += 192; + p0 += bottom_blob.cstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + const float* p0 = (const float*)bottom_blob.channel(k / 8) + (j + jj) * 8; + + int kk = 0; + for (; kk < max_kk / 8; kk++) + { + __m256 _r0 = _mm256_load_ps(p0); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); + __m256 _r4 = _mm256_load_ps(p0 + 8 * 4); + __m256 _r5 = _mm256_load_ps(p0 + 8 * 5); + __m256 _r6 = _mm256_load_ps(p0 + 8 * 6); + __m256 _r7 = _mm256_load_ps(p0 + 8 * 7); + __m256 _r8 = _mm256_load_ps(p0 + 8 * 8); + __m256 _r9 = _mm256_load_ps(p0 + 8 * 9); + __m256 _ra = _mm256_load_ps(p0 + 8 * 10); + __m256 _rb = _mm256_load_ps(p0 + 8 * 11); + transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + _mm256_store_ps(pp + 8 * 8, _r8); + _mm256_store_ps(pp + 8 * 9, _r9); + _mm256_store_ps(pp + 8 * 10, _ra); + _mm256_store_ps(pp + 8 * 11, _rb); + pp += 96; + p0 += bottom_blob.cstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + const float* p0 = (const float*)bottom_blob.channel(k / 4) + (j + jj) * 4; + + int kk = 0; + for (; kk < max_kk / 4; kk++) + { + __m128 _r0 = _mm_load_ps(p0); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); + __m128 _r4 = _mm_load_ps(p0 + 4 * 4); + __m128 _r5 = _mm_load_ps(p0 + 4 * 5); + __m128 _r6 = _mm_load_ps(p0 + 4 * 6); + __m128 _r7 = _mm_load_ps(p0 + 4 * 7); + __m128 _r8 = _mm_load_ps(p0 + 4 * 8); + __m128 _r9 = _mm_load_ps(p0 + 4 * 9); + __m128 _ra = _mm_load_ps(p0 + 4 * 10); + __m128 _rb = _mm_load_ps(p0 + 4 * 11); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r8); + _mm_store_ps(pp + 4 * 3, _r1); + _mm_store_ps(pp + 4 * 4, _r5); + _mm_store_ps(pp + 4 * 5, _r9); + _mm_store_ps(pp + 4 * 6, _r2); + _mm_store_ps(pp + 4 * 7, _r6); + _mm_store_ps(pp + 4 * 8, _ra); + _mm_store_ps(pp + 4 * 9, _r3); + _mm_store_ps(pp + 4 * 10, _r7); + _mm_store_ps(pp + 4 * 11, _rb); + pp += 48; + p0 += bottom_blob.cstep * 4; + } + } + + if (elempack == 1) + { + const float* p0 = (const float*)bottom_blob.channel(k) + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __m128 _r0 = _mm_loadu_ps(p0); + __m128 _r1 = _mm_loadu_ps(p0 + 4); + __m128 _r2 = _mm_loadu_ps(p0 + 8); + _mm_storeu_ps(pp, _r0); + _mm_storeu_ps(pp + 4, _r1); + _mm_storeu_ps(pp + 8, _r2); + pp += 12; + p0 += bottom_blob.cstep; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + const float* p0 = (const float*)bottom_blob.channel(k / 16) + (j + jj) * 16; + + int kk = 0; + for (; kk < max_kk / 16; kk++) + { + __m512 _r0 = _mm512_load_ps(p0); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); + __m512 _r4 = _mm512_load_ps(p0 + 16 * 4); + __m512 _r5 = _mm512_load_ps(p0 + 16 * 5); + __m512 _r6 = _mm512_load_ps(p0 + 16 * 6); + __m512 _r7 = _mm512_load_ps(p0 + 16 * 7); + transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + pp += 128; + p0 += bottom_blob.cstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + const float* p0 = (const float*)bottom_blob.channel(k / 8) + (j + jj) * 8; + + int kk = 0; + for (; kk < max_kk / 8; kk++) + { + __m256 _r0 = _mm256_load_ps(p0); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); + __m256 _r4 = _mm256_load_ps(p0 + 8 * 4); + __m256 _r5 = _mm256_load_ps(p0 + 8 * 5); + __m256 _r6 = _mm256_load_ps(p0 + 8 * 6); + __m256 _r7 = _mm256_load_ps(p0 + 8 * 7); + transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + pp += 64; + p0 += bottom_blob.cstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + const float* p0 = (const float*)bottom_blob.channel(k / 4) + (j + jj) * 4; + + int kk = 0; + for (; kk < max_kk / 4; kk++) + { + __m128 _r0 = _mm_load_ps(p0); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); + __m128 _r4 = _mm_load_ps(p0 + 4 * 4); + __m128 _r5 = _mm_load_ps(p0 + 4 * 5); + __m128 _r6 = _mm_load_ps(p0 + 4 * 6); + __m128 _r7 = _mm_load_ps(p0 + 4 * 7); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r1); + _mm_store_ps(pp + 4 * 3, _r5); + _mm_store_ps(pp + 4 * 4, _r2); + _mm_store_ps(pp + 4 * 5, _r6); + _mm_store_ps(pp + 4 * 6, _r3); + _mm_store_ps(pp + 4 * 7, _r7); + pp += 32; + p0 += bottom_blob.cstep * 4; + } + } + + if (elempack == 1) + { + const float* p0 = (const float*)bottom_blob.channel(k) + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + __m128 _r0 = _mm_loadu_ps(p0); + __m128 _r1 = _mm_loadu_ps(p0 + 4); + _mm_storeu_ps(pp, _r0); + _mm_storeu_ps(pp + 4, _r1); + pp += 8; + p0 += bottom_blob.cstep; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + const float* p0 = (const float*)bottom_blob.channel(k / 16) + (j + jj) * 16; + + int kk = 0; + for (; kk < max_kk / 16; kk++) + { + __m512 _r0 = _mm512_load_ps(p0); + __m512 _r1 = _mm512_load_ps(p0 + 16); + __m512 _r2 = _mm512_load_ps(p0 + 16 * 2); + __m512 _r3 = _mm512_load_ps(p0 + 16 * 3); + transpose16x4_ps(_r0, _r1, _r2, _r3); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + pp += 64; + p0 += bottom_blob.cstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + const float* p0 = (const float*)bottom_blob.channel(k / 8) + (j + jj) * 8; + + int kk = 0; + for (; kk < max_kk / 8; kk++) + { + __m256 _r0 = _mm256_load_ps(p0); + __m256 _r1 = _mm256_load_ps(p0 + 8); + __m256 _r2 = _mm256_load_ps(p0 + 8 * 2); + __m256 _r3 = _mm256_load_ps(p0 + 8 * 3); + transpose8x4_ps(_r0, _r1, _r2, _r3); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + pp += 32; + p0 += bottom_blob.cstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + const float* p0 = (const float*)bottom_blob.channel(k / 4) + (j + jj) * 4; + + int kk = 0; + for (; kk < max_kk / 4; kk++) + { + __m128 _r0 = _mm_load_ps(p0); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _r2 = _mm_load_ps(p0 + 4 * 2); + __m128 _r3 = _mm_load_ps(p0 + 4 * 3); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r1); + _mm_store_ps(pp + 4 * 2, _r2); + _mm_store_ps(pp + 4 * 3, _r3); + pp += 16; + p0 += bottom_blob.cstep * 4; + } + } + + if (elempack == 1) + { + const float* p0 = (const float*)bottom_blob.channel(k) + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + _mm_storeu_ps(pp, _mm_loadu_ps(p0)); + pp += 4; + p0 += bottom_blob.cstep; + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + const float* p0 = (const float*)bottom_blob.channel(k / 16) + (j + jj) * 16; + + int kk = 0; + for (; kk < max_kk / 16; kk++) + { + __m512 _r0 = _mm512_load_ps(p0); + __m512 _r1 = _mm512_load_ps(p0 + 16); + transpose16x2_ps(_r0, _r1); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16, _r1); + pp += 32; + p0 += bottom_blob.cstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + const float* p0 = (const float*)bottom_blob.channel(k / 8) + (j + jj) * 8; + + int kk = 0; + for (; kk < max_kk / 8; kk++) + { + __m256 _r0 = _mm256_load_ps(p0); + __m256 _r1 = _mm256_load_ps(p0 + 8); + transpose8x2_ps(_r0, _r1); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8, _r1); + pp += 16; + p0 += bottom_blob.cstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + const float* p0 = (const float*)bottom_blob.channel(k / 4) + (j + jj) * 4; + + int kk = 0; + for (; kk < max_kk / 4; kk++) + { + // transpose4x2 + __m128 _r0 = _mm_load_ps(p0); + __m128 _r1 = _mm_load_ps(p0 + 4); + __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); + __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); + _mm_store_ps(pp, _tmp0); + _mm_store_ps(pp + 4, _tmp1); + pp += 8; + p0 += bottom_blob.cstep * 4; + } + } +#endif // __SSE2__ + + if (elempack == 1) + { + const float* p0 = (const float*)bottom_blob.channel(k) + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += bottom_blob.cstep; + } + } + } + for (; jj < max_jj; jj++) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + const float* p0 = (const float*)bottom_blob.channel(k / 16) + (j + jj) * 16; + + int kk = 0; + for (; kk < max_kk / 16; kk++) + { + _mm512_store_ps(pp, _mm512_load_ps(p0)); + pp += 16; + p0 += bottom_blob.cstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + const float* p0 = (const float*)bottom_blob.channel(k / 8) + (j + jj) * 8; + + int kk = 0; + for (; kk < max_kk / 8; kk++) + { + _mm256_store_ps(pp, _mm256_load_ps(p0)); + pp += 8; + p0 += bottom_blob.cstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + const float* p0 = (const float*)bottom_blob.channel(k / 4) + (j + jj) * 4; + + int kk = 0; + for (; kk < max_kk / 4; kk++) + { + _mm_store_ps(pp, _mm_load_ps(p0)); + pp += 4; + p0 += bottom_blob.cstep * 4; + } + } +#endif // __SSE2__ + + if (elempack == 1) + { + const float* p0 = (const float*)bottom_blob.channel(k) + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += bottom_blob.cstep; + } + } + } +} + +static inline void convolution_im2col_input_tile_impl(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +{ + const int w = bottom_blob.w; + // const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int outw = (w - kernel_extent_w) / stride_w + 1; + + // j max_jj outw*outh split w and h + + // k max_kk pa*maxk*(inch/pa) split inch + + // k/max_kk shall be multiple of maxk + + const int maxk = kernel_w * kernel_h; + + float* pp = B; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 11 < max_jj; jj += 12) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dy8 = (j + jj + 8) / outw; + int dy9 = (j + jj + 9) / outw; + int dya = (j + jj + 10) / outw; + int dyb = (j + jj + 11) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + int dx8 = (j + jj + 8) % outw; + int dx9 = (j + jj + 9) % outw; + int dxa = (j + jj + 10) % outw; + int dxb = (j + jj + 11) % outw; + + if (dy0 == dyb) + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const float* sptr = img.row(y0) + x0 * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr); + __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); + __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); + __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); + __m512 _r4 = _mm512_load_ps(sptr + stride_w * 64); + __m512 _r5 = _mm512_load_ps(sptr + stride_w * 80); + __m512 _r6 = _mm512_load_ps(sptr + stride_w * 96); + __m512 _r7 = _mm512_load_ps(sptr + stride_w * 112); + __m512 _r8 = _mm512_load_ps(sptr + stride_w * 128); + __m512 _r9 = _mm512_load_ps(sptr + stride_w * 144); + __m512 _ra = _mm512_load_ps(sptr + stride_w * 160); + __m512 _rb = _mm512_load_ps(sptr + stride_w * 176); + transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + _mm512_store_ps(pp + 16 * 8, _r8); + _mm512_store_ps(pp + 16 * 9, _r9); + _mm512_store_ps(pp + 16 * 10, _ra); + _mm512_store_ps(pp + 16 * 11, _rb); + pp += 192; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr); + __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); + __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); + __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); + __m256 _r4 = _mm256_load_ps(sptr + stride_w * 32); + __m256 _r5 = _mm256_load_ps(sptr + stride_w * 40); + __m256 _r6 = _mm256_load_ps(sptr + stride_w * 48); + __m256 _r7 = _mm256_load_ps(sptr + stride_w * 56); + __m256 _r8 = _mm256_load_ps(sptr + stride_w * 64); + __m256 _r9 = _mm256_load_ps(sptr + stride_w * 72); + __m256 _ra = _mm256_load_ps(sptr + stride_w * 80); + __m256 _rb = _mm256_load_ps(sptr + stride_w * 88); + transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + _mm256_store_ps(pp + 8 * 8, _r8); + _mm256_store_ps(pp + 8 * 9, _r9); + _mm256_store_ps(pp + 8 * 10, _ra); + _mm256_store_ps(pp + 8 * 11, _rb); + pp += 96; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr); + __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); + __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); + __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); + __m128 _r4 = _mm_load_ps(sptr + stride_w * 16); + __m128 _r5 = _mm_load_ps(sptr + stride_w * 20); + __m128 _r6 = _mm_load_ps(sptr + stride_w * 24); + __m128 _r7 = _mm_load_ps(sptr + stride_w * 28); + __m128 _r8 = _mm_load_ps(sptr + stride_w * 32); + __m128 _r9 = _mm_load_ps(sptr + stride_w * 36); + __m128 _ra = _mm_load_ps(sptr + stride_w * 40); + __m128 _rb = _mm_load_ps(sptr + stride_w * 44); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r8); + _mm_store_ps(pp + 4 * 3, _r1); + _mm_store_ps(pp + 4 * 4, _r5); + _mm_store_ps(pp + 4 * 5, _r9); + _mm_store_ps(pp + 4 * 6, _r2); + _mm_store_ps(pp + 4 * 7, _r6); + _mm_store_ps(pp + 4 * 8, _ra); + _mm_store_ps(pp + 4 * 9, _r3); + _mm_store_ps(pp + 4 * 10, _r7); + _mm_store_ps(pp + 4 * 11, _rb); + pp += 48; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp += 12; + } + } + } + else + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + + const float* sptr0 = img.row(y0) + x0 * elempack; + const float* sptr1 = img.row(y1) + x1 * elempack; + const float* sptr2 = img.row(y2) + x2 * elempack; + const float* sptr3 = img.row(y3) + x3 * elempack; + const float* sptr4 = img.row(y4) + x4 * elempack; + const float* sptr5 = img.row(y5) + x5 * elempack; + const float* sptr6 = img.row(y6) + x6 * elempack; + const float* sptr7 = img.row(y7) + x7 * elempack; + const float* sptr8 = img.row(y8) + x8 * elempack; + const float* sptr9 = img.row(y9) + x9 * elempack; + const float* sptra = img.row(ya) + xa * elempack; + const float* sptrb = img.row(yb) + xb * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr0); + __m512 _r1 = _mm512_load_ps(sptr1); + __m512 _r2 = _mm512_load_ps(sptr2); + __m512 _r3 = _mm512_load_ps(sptr3); + __m512 _r4 = _mm512_load_ps(sptr4); + __m512 _r5 = _mm512_load_ps(sptr5); + __m512 _r6 = _mm512_load_ps(sptr6); + __m512 _r7 = _mm512_load_ps(sptr7); + __m512 _r8 = _mm512_load_ps(sptr8); + __m512 _r9 = _mm512_load_ps(sptr9); + __m512 _ra = _mm512_load_ps(sptra); + __m512 _rb = _mm512_load_ps(sptrb); + transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + _mm512_store_ps(pp + 16 * 8, _r8); + _mm512_store_ps(pp + 16 * 9, _r9); + _mm512_store_ps(pp + 16 * 10, _ra); + _mm512_store_ps(pp + 16 * 11, _rb); + pp += 192; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr0); + __m256 _r1 = _mm256_load_ps(sptr1); + __m256 _r2 = _mm256_load_ps(sptr2); + __m256 _r3 = _mm256_load_ps(sptr3); + __m256 _r4 = _mm256_load_ps(sptr4); + __m256 _r5 = _mm256_load_ps(sptr5); + __m256 _r6 = _mm256_load_ps(sptr6); + __m256 _r7 = _mm256_load_ps(sptr7); + __m256 _r8 = _mm256_load_ps(sptr8); + __m256 _r9 = _mm256_load_ps(sptr9); + __m256 _ra = _mm256_load_ps(sptra); + __m256 _rb = _mm256_load_ps(sptrb); + transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + _mm256_store_ps(pp + 8 * 8, _r8); + _mm256_store_ps(pp + 8 * 9, _r9); + _mm256_store_ps(pp + 8 * 10, _ra); + _mm256_store_ps(pp + 8 * 11, _rb); + pp += 96; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr0); + __m128 _r1 = _mm_load_ps(sptr1); + __m128 _r2 = _mm_load_ps(sptr2); + __m128 _r3 = _mm_load_ps(sptr3); + __m128 _r4 = _mm_load_ps(sptr4); + __m128 _r5 = _mm_load_ps(sptr5); + __m128 _r6 = _mm_load_ps(sptr6); + __m128 _r7 = _mm_load_ps(sptr7); + __m128 _r8 = _mm_load_ps(sptr8); + __m128 _r9 = _mm_load_ps(sptr9); + __m128 _ra = _mm_load_ps(sptra); + __m128 _rb = _mm_load_ps(sptrb); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r8); + _mm_store_ps(pp + 4 * 3, _r1); + _mm_store_ps(pp + 4 * 4, _r5); + _mm_store_ps(pp + 4 * 5, _r9); + _mm_store_ps(pp + 4 * 6, _r2); + _mm_store_ps(pp + 4 * 7, _r6); + _mm_store_ps(pp + 4 * 8, _ra); + _mm_store_ps(pp + 4 * 9, _r3); + _mm_store_ps(pp + 4 * 10, _r7); + _mm_store_ps(pp + 4 * 11, _rb); + pp += 48; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp += 12; + } + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + + if (dy0 == dy7) + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const float* sptr = img.row(y0) + x0 * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr); + __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); + __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); + __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); + __m512 _r4 = _mm512_load_ps(sptr + stride_w * 64); + __m512 _r5 = _mm512_load_ps(sptr + stride_w * 80); + __m512 _r6 = _mm512_load_ps(sptr + stride_w * 96); + __m512 _r7 = _mm512_load_ps(sptr + stride_w * 112); + transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + pp += 128; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr); + __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); + __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); + __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); + __m256 _r4 = _mm256_load_ps(sptr + stride_w * 32); + __m256 _r5 = _mm256_load_ps(sptr + stride_w * 40); + __m256 _r6 = _mm256_load_ps(sptr + stride_w * 48); + __m256 _r7 = _mm256_load_ps(sptr + stride_w * 56); + transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + pp += 64; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr); + __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); + __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); + __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); + __m128 _r4 = _mm_load_ps(sptr + stride_w * 16); + __m128 _r5 = _mm_load_ps(sptr + stride_w * 20); + __m128 _r6 = _mm_load_ps(sptr + stride_w * 24); + __m128 _r7 = _mm_load_ps(sptr + stride_w * 28); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r1); + _mm_store_ps(pp + 4 * 3, _r5); + _mm_store_ps(pp + 4 * 4, _r2); + _mm_store_ps(pp + 4 * 5, _r6); + _mm_store_ps(pp + 4 * 6, _r3); + _mm_store_ps(pp + 4 * 7, _r7); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } + } + } + else + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const float* sptr0 = img.row(y0) + x0 * elempack; + const float* sptr1 = img.row(y1) + x1 * elempack; + const float* sptr2 = img.row(y2) + x2 * elempack; + const float* sptr3 = img.row(y3) + x3 * elempack; + const float* sptr4 = img.row(y4) + x4 * elempack; + const float* sptr5 = img.row(y5) + x5 * elempack; + const float* sptr6 = img.row(y6) + x6 * elempack; + const float* sptr7 = img.row(y7) + x7 * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr0); + __m512 _r1 = _mm512_load_ps(sptr1); + __m512 _r2 = _mm512_load_ps(sptr2); + __m512 _r3 = _mm512_load_ps(sptr3); + __m512 _r4 = _mm512_load_ps(sptr4); + __m512 _r5 = _mm512_load_ps(sptr5); + __m512 _r6 = _mm512_load_ps(sptr6); + __m512 _r7 = _mm512_load_ps(sptr7); + transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + _mm512_store_ps(pp + 16 * 4, _r4); + _mm512_store_ps(pp + 16 * 5, _r5); + _mm512_store_ps(pp + 16 * 6, _r6); + _mm512_store_ps(pp + 16 * 7, _r7); + pp += 128; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr0); + __m256 _r1 = _mm256_load_ps(sptr1); + __m256 _r2 = _mm256_load_ps(sptr2); + __m256 _r3 = _mm256_load_ps(sptr3); + __m256 _r4 = _mm256_load_ps(sptr4); + __m256 _r5 = _mm256_load_ps(sptr5); + __m256 _r6 = _mm256_load_ps(sptr6); + __m256 _r7 = _mm256_load_ps(sptr7); + transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + _mm256_store_ps(pp + 8 * 4, _r4); + _mm256_store_ps(pp + 8 * 5, _r5); + _mm256_store_ps(pp + 8 * 6, _r6); + _mm256_store_ps(pp + 8 * 7, _r7); + pp += 64; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr0); + __m128 _r1 = _mm_load_ps(sptr1); + __m128 _r2 = _mm_load_ps(sptr2); + __m128 _r3 = _mm_load_ps(sptr3); + __m128 _r4 = _mm_load_ps(sptr4); + __m128 _r5 = _mm_load_ps(sptr5); + __m128 _r6 = _mm_load_ps(sptr6); + __m128 _r7 = _mm_load_ps(sptr7); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r4); + _mm_store_ps(pp + 4 * 2, _r1); + _mm_store_ps(pp + 4 * 3, _r5); + _mm_store_ps(pp + 4 * 4, _r2); + _mm_store_ps(pp + 4 * 5, _r6); + _mm_store_ps(pp + 4 * 6, _r3); + _mm_store_ps(pp + 4 * 7, _r7); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + + if (dy0 == dy3) + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const float* sptr = img.row(y0) + x0 * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr); + __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); + __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); + __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); + transpose16x4_ps(_r0, _r1, _r2, _r3); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + pp += 64; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr); + __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); + __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); + __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); + transpose8x4_ps(_r0, _r1, _r2, _r3); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + pp += 32; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr); + __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); + __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); + __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r1); + _mm_store_ps(pp + 4 * 2, _r2); + _mm_store_ps(pp + 4 * 3, _r3); + pp += 16; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } + } + } + else + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const float* sptr0 = img.row(y0) + x0 * elempack; + const float* sptr1 = img.row(y1) + x1 * elempack; + const float* sptr2 = img.row(y2) + x2 * elempack; + const float* sptr3 = img.row(y3) + x3 * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr0); + __m512 _r1 = _mm512_load_ps(sptr1); + __m512 _r2 = _mm512_load_ps(sptr2); + __m512 _r3 = _mm512_load_ps(sptr3); + transpose16x4_ps(_r0, _r1, _r2, _r3); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16 * 1, _r1); + _mm512_store_ps(pp + 16 * 2, _r2); + _mm512_store_ps(pp + 16 * 3, _r3); + pp += 64; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr0); + __m256 _r1 = _mm256_load_ps(sptr1); + __m256 _r2 = _mm256_load_ps(sptr2); + __m256 _r3 = _mm256_load_ps(sptr3); + transpose8x4_ps(_r0, _r1, _r2, _r3); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8 * 1, _r1); + _mm256_store_ps(pp + 8 * 2, _r2); + _mm256_store_ps(pp + 8 * 3, _r3); + pp += 32; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr0); + __m128 _r1 = _mm_load_ps(sptr1); + __m128 _r2 = _mm_load_ps(sptr2); + __m128 _r3 = _mm_load_ps(sptr3); + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _mm_store_ps(pp, _r0); + _mm_store_ps(pp + 4 * 1, _r1); + _mm_store_ps(pp + 4 * 2, _r2); + _mm_store_ps(pp + 4 * 3, _r3); + pp += 16; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + + if (dy0 == dy1) + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const float* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr); + __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); + transpose16x2_ps(_r0, _r1); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16, _r1); + pp += 32; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr); + __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); + transpose8x2_ps(_r0, _r1); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8, _r1); + pp += 16; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr); + __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); + __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); + __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); + _mm_store_ps(pp, _tmp0); + _mm_store_ps(pp + 4, _tmp1); + pp += 8; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const float* sptr0 = img.row(y0) + x0 * elempack; + const float* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _r0 = _mm512_load_ps(sptr0); + __m512 _r1 = _mm512_load_ps(sptr1); + transpose16x2_ps(_r0, _r1); + _mm512_store_ps(pp, _r0); + _mm512_store_ps(pp + 16, _r1); + pp += 32; + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _r0 = _mm256_load_ps(sptr0); + __m256 _r1 = _mm256_load_ps(sptr1); + transpose8x2_ps(_r0, _r1); + _mm256_store_ps(pp, _r0); + _mm256_store_ps(pp + 8, _r1); + pp += 16; + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _r0 = _mm_load_ps(sptr0); + __m128 _r1 = _mm_load_ps(sptr1); + __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); + __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); + _mm_store_ps(pp, _tmp0); + _mm_store_ps(pp + 4, _tmp1); + pp += 8; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } + } + } + } + for (; jj < max_jj; jj++) + { + int dy = (j + jj) / outw; + int dx = (j + jj) % outw; + + int kk = 0; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x = stride_w * dx + dilation_w * v; + int y = stride_h * dy + dilation_h * u; + + const float* sptr = img.row(y) + x * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + _mm512_store_ps(pp, _mm512_load_ps(sptr)); + pp += 16; + } +#endif // __AVX512F__ + if (elempack == 8) + { + _mm256_store_ps(pp, _mm256_load_ps(sptr)); + pp += 8; + } +#endif // __AVX__ + if (elempack == 4) + { + _mm_store_ps(pp, _mm_load_ps(sptr)); + pp += 4; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp += 1; + } + } + } +} + +template +#if __AVX512F__ +void convolution_im2col_input_tile_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#elif __AVX__ +void convolution_im2col_input_tile_avx(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else +void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif +{ + convolution_im2col_input_tile_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); +} + +#if __AVX512F__ +template void convolution_im2col_input_tile_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx512<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx512<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx512<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx512<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#elif __AVX__ +template void convolution_im2col_input_tile_avx<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_avx<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#else +template void convolution_im2col_input_tile<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#endif + +static void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +{ + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_input_tile_conv1x1s1d1(bottom_blob, B, j, max_jj, k, max_kk); + return; + } + + if (kernel_w == 1 && kernel_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_avx512<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#elif __AVX__ + convolution_im2col_input_tile_avx<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else + convolution_im2col_input_tile<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif + return; + } + + convolution_im2col_input_tile_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); +} + +static void convolution_im2col_gemm_transform_kernel(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt) +{ + // NCNN_LOGE("convolution_im2col_gemm_transform_kernel"); + const int maxk = kernel_w * kernel_h; + + const int M = outch; + const int K = inch * maxk; + + int TILE_M, TILE_N, TILE_K; + convolution_im2col_gemm_get_optimal_tile_mnk(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + int elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { +#if __AVX512F__ + elempack = inch % 16 == 0 ? 16 : inch % 8 == 0 ? 8 : inch % 4 == 0 ? 4 : 1; +#elif __AVX__ + elempack = inch % 8 == 0 ? 8 : inch % 4 == 0 ? 4 : 1; +#else + elempack = inch % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + + // maxk-inch-outch to pa-maxk-inch/pa-outch + Mat A_data; + if (maxk == 1) + { + A_data = kernel.reshape(maxk * inch, outch); + } + else + { + Mat weight_data_r2 = kernel.reshape(maxk, inch, outch); + + A_data.create(maxk * inch, outch); + + for (int q = 0; q < outch; q += 1) + { + float* g00 = A_data.row(q); + + for (int p = 0; p + (elempack - 1) < inch; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < elempack; i++) + { + const float* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = k00[k]; + g00++; + } + } + } + } + } + + AT.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + convolution_im2col_pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk); + } + } +} + +static void convolution_im2col_gemm(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, const Mat& bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt) +{ + const int maxk = kernel_w * kernel_h; + + const int M = top_blob.c * top_blob.elempack; + const int N = top_blob.w * top_blob.h; + const int K = bottom_blob.c * bottom_blob.elempack * maxk; + + int TILE_M, TILE_N, TILE_K; + convolution_im2col_gemm_get_optimal_tile_mnk(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + // im2col + convolution_im2col_input_tile(bottom_blob, BT_tile, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); + } + + Mat topT_tileX; + if (K > TILE_K) + topT_tileX.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat topT_tile; + if (K > TILE_K) + topT_tile = topT_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + const Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = k + TILE_K >= K; + + convolution_gemm_transB_packed_tile(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end); + } + } + } +} diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index e72dd8882dd..351987abaab 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -6138,12 +6138,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo } } -template -#if __AVX512F__ -void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) -#else // __AVX512F__ -void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) -#endif // __AVX512F__ +static inline void convolution_im2col_input_tile_int8_impl(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) { const int w = bottom_blob.w; // const int channels = bottom_blob.c; @@ -7382,6 +7377,16 @@ void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, i } } +template +#if __AVX512F__ +void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else // __AVX512F__ +void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif // __AVX512F__ +{ + convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); +} + #if __AVX512F__ template void convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); @@ -7466,1241 +7471,7 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i return; } - const int w = bottom_blob.w; - // const int channels = bottom_blob.c; - const int elempack = bottom_blob.elempack; - - const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; - const int outw = (w - kernel_extent_w) / stride_w + 1; - - // j max_jj outw*outh split w and h - - // k max_kk pa*maxk*(inch/pa) split inch - - // k/max_kk shall be multiple of maxk - - const int maxk = kernel_w * kernel_h; - - signed char* pp = B; - - int jj = 0; -#if __SSE2__ -#if defined(__x86_64__) || defined(_M_X64) -#if __AVX512F__ - for (; jj + 15 < max_jj; jj += 16) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dy4 = (j + jj + 4) / outw; - int dy5 = (j + jj + 5) / outw; - int dy6 = (j + jj + 6) / outw; - int dy7 = (j + jj + 7) / outw; - int dy8 = (j + jj + 8) / outw; - int dy9 = (j + jj + 9) / outw; - int dya = (j + jj + 10) / outw; - int dyb = (j + jj + 11) / outw; - int dyc = (j + jj + 12) / outw; - int dyd = (j + jj + 13) / outw; - int dye = (j + jj + 14) / outw; - int dyf = (j + jj + 15) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - int dx4 = (j + jj + 4) % outw; - int dx5 = (j + jj + 5) % outw; - int dx6 = (j + jj + 6) % outw; - int dx7 = (j + jj + 7) % outw; - int dx8 = (j + jj + 8) % outw; - int dx9 = (j + jj + 9) % outw; - int dxa = (j + jj + 10) % outw; - int dxb = (j + jj + 11) % outw; - int dxc = (j + jj + 12) % outw; - int dxd = (j + jj + 13) % outw; - int dxe = (j + jj + 14) % outw; - int dxf = (j + jj + 15) % outw; - - if (dy0 == dyf) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); - __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); - _mm_store_si128((__m128i*)pp, _tmp0); - _mm_store_si128((__m128i*)(pp + 16), _tmp1); - pp += 32; - } - else if (stride_w == 2) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); - __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); - __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); - _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); - __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _mm256_storeu_si256((__m256i*)pp, _r01); - pp += 32; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp[8] = sptr0[stride_w * 4]; - pp[9] = sptr1[stride_w * 4]; - pp[10] = sptr0[stride_w * 5]; - pp[11] = sptr1[stride_w * 5]; - pp[12] = sptr0[stride_w * 6]; - pp[13] = sptr1[stride_w * 6]; - pp[14] = sptr0[stride_w * 7]; - pp[15] = sptr1[stride_w * 7]; - pp[16 + 0] = sptr0[stride_w * 8]; - pp[16 + 1] = sptr1[stride_w * 8]; - pp[16 + 2] = sptr0[stride_w * 9]; - pp[16 + 3] = sptr1[stride_w * 9]; - pp[16 + 4] = sptr0[stride_w * 10]; - pp[16 + 5] = sptr1[stride_w * 10]; - pp[16 + 6] = sptr0[stride_w * 11]; - pp[16 + 7] = sptr1[stride_w * 11]; - pp[16 + 8] = sptr0[stride_w * 12]; - pp[16 + 9] = sptr1[stride_w * 12]; - pp[16 + 10] = sptr0[stride_w * 13]; - pp[16 + 11] = sptr1[stride_w * 13]; - pp[16 + 12] = sptr0[stride_w * 14]; - pp[16 + 13] = sptr1[stride_w * 14]; - pp[16 + 14] = sptr0[stride_w * 15]; - pp[16 + 15] = sptr1[stride_w * 15]; - pp += 32; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); - __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); - __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); - __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); - __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); - __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); - __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_store_si128((__m128i*)pp, _r8); - _mm_store_si128((__m128i*)(pp + 16), _r9); - _mm_store_si128((__m128i*)(pp + 32), _ra); - _mm_store_si128((__m128i*)(pp + 48), _rb); - _mm_store_si128((__m128i*)(pp + 64), _rc); - _mm_store_si128((__m128i*)(pp + 80), _rd); - _mm_store_si128((__m128i*)(pp + 96), _re); - _mm_store_si128((__m128i*)(pp + 112), _rf); - pp += 128; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp[8] = sptr[stride_w * 8]; - pp[9] = sptr[stride_w * 9]; - pp[10] = sptr[stride_w * 10]; - pp[11] = sptr[stride_w * 11]; - pp[12] = sptr[stride_w * 12]; - pp[13] = sptr[stride_w * 13]; - pp[14] = sptr[stride_w * 14]; - pp[15] = sptr[stride_w * 15]; - pp += 16; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int x08 = stride_w * dx8 + dilation_w * v0; - int x09 = stride_w * dx9 + dilation_w * v0; - int x0a = stride_w * dxa + dilation_w * v0; - int x0b = stride_w * dxb + dilation_w * v0; - int x0c = stride_w * dxc + dilation_w * v0; - int x0d = stride_w * dxd + dilation_w * v0; - int x0e = stride_w * dxe + dilation_w * v0; - int x0f = stride_w * dxf + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - int y08 = stride_h * dy8 + dilation_h * u0; - int y09 = stride_h * dy9 + dilation_h * u0; - int y0a = stride_h * dya + dilation_h * u0; - int y0b = stride_h * dyb + dilation_h * u0; - int y0c = stride_h * dyc + dilation_h * u0; - int y0d = stride_h * dyd + dilation_h * u0; - int y0e = stride_h * dye + dilation_h * u0; - int y0f = stride_h * dyf + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int x18 = stride_w * dx8 + dilation_w * v1; - int x19 = stride_w * dx9 + dilation_w * v1; - int x1a = stride_w * dxa + dilation_w * v1; - int x1b = stride_w * dxb + dilation_w * v1; - int x1c = stride_w * dxc + dilation_w * v1; - int x1d = stride_w * dxd + dilation_w * v1; - int x1e = stride_w * dxe + dilation_w * v1; - int x1f = stride_w * dxf + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - int y18 = stride_h * dy8 + dilation_h * u1; - int y19 = stride_h * dy9 + dilation_h * u1; - int y1a = stride_h * dya + dilation_h * u1; - int y1b = stride_h * dyb + dilation_h * u1; - int y1c = stride_h * dyc + dilation_h * u1; - int y1d = stride_h * dyd + dilation_h * u1; - int y1e = stride_h * dye + dilation_h * u1; - int y1f = stride_h * dyf + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - const signed char* sptr08 = img0.row(y08) + x08; - const signed char* sptr09 = img0.row(y09) + x09; - const signed char* sptr0a = img0.row(y0a) + x0a; - const signed char* sptr0b = img0.row(y0b) + x0b; - const signed char* sptr0c = img0.row(y0c) + x0c; - const signed char* sptr0d = img0.row(y0d) + x0d; - const signed char* sptr0e = img0.row(y0e) + x0e; - const signed char* sptr0f = img0.row(y0f) + x0f; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - const signed char* sptr18 = img1.row(y18) + x18; - const signed char* sptr19 = img1.row(y19) + x19; - const signed char* sptr1a = img1.row(y1a) + x1a; - const signed char* sptr1b = img1.row(y1b) + x1b; - const signed char* sptr1c = img1.row(y1c) + x1c; - const signed char* sptr1d = img1.row(y1d) + x1d; - const signed char* sptr1e = img1.row(y1e) + x1e; - const signed char* sptr1f = img1.row(y1f) + x1f; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp[16 + 0] = sptr08[0]; - pp[16 + 1] = sptr18[0]; - pp[16 + 2] = sptr09[0]; - pp[16 + 3] = sptr19[0]; - pp[16 + 4] = sptr0a[0]; - pp[16 + 5] = sptr1a[0]; - pp[16 + 6] = sptr0b[0]; - pp[16 + 7] = sptr1b[0]; - pp[16 + 8] = sptr0c[0]; - pp[16 + 9] = sptr1c[0]; - pp[16 + 10] = sptr0d[0]; - pp[16 + 11] = sptr1d[0]; - pp[16 + 12] = sptr0e[0]; - pp[16 + 13] = sptr1e[0]; - pp[16 + 14] = sptr0f[0]; - pp[16 + 15] = sptr1f[0]; - pp += 32; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - int xc = stride_w * dxc + dilation_w * v; - int xd = stride_w * dxd + dilation_w * v; - int xe = stride_w * dxe + dilation_w * v; - int xf = stride_w * dxf + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - int yc = stride_h * dyc + dilation_h * u; - int yd = stride_h * dyd + dilation_h * u; - int ye = stride_h * dye + dilation_h * u; - int yf = stride_h * dyf + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - const signed char* sptr8 = img.row(y8) + x8 * elempack; - const signed char* sptr9 = img.row(y9) + x9 * elempack; - const signed char* sptra = img.row(ya) + xa * elempack; - const signed char* sptrb = img.row(yb) + xb * elempack; - const signed char* sptrc = img.row(yc) + xc * elempack; - const signed char* sptrd = img.row(yd) + xd * elempack; - const signed char* sptre = img.row(ye) + xe * elempack; - const signed char* sptrf = img.row(yf) + xf * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); - __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); - __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); - __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); - __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); - __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); - __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_store_si128((__m128i*)pp, _r8); - _mm_store_si128((__m128i*)(pp + 16), _r9); - _mm_store_si128((__m128i*)(pp + 32), _ra); - _mm_store_si128((__m128i*)(pp + 48), _rb); - _mm_store_si128((__m128i*)(pp + 64), _rc); - _mm_store_si128((__m128i*)(pp + 80), _rd); - _mm_store_si128((__m128i*)(pp + 96), _re); - _mm_store_si128((__m128i*)(pp + 112), _rf); - pp += 128; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp[12] = sptrc[0]; - pp[13] = sptrd[0]; - pp[14] = sptre[0]; - pp[15] = sptrf[0]; - pp += 16; - } - } - } - } -#endif // __AVX512F__ - for (; jj + 7 < max_jj; jj += 8) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dy4 = (j + jj + 4) / outw; - int dy5 = (j + jj + 5) / outw; - int dy6 = (j + jj + 6) / outw; - int dy7 = (j + jj + 7) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - int dx4 = (j + jj + 4) % outw; - int dx5 = (j + jj + 5) % outw; - int dx6 = (j + jj + 6) % outw; - int dx7 = (j + jj + 7) % outw; - - if (dy0 == dy7) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } - else if (stride_w == 2) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); - __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); - _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp[8] = sptr0[stride_w * 4]; - pp[9] = sptr1[stride_w * 4]; - pp[10] = sptr0[stride_w * 5]; - pp[11] = sptr1[stride_w * 5]; - pp[12] = sptr0[stride_w * 6]; - pp[13] = sptr1[stride_w * 6]; - pp[14] = sptr0[stride_w * 7]; - pp[15] = sptr1[stride_w * 7]; - pp += 16; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp += 8; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp += 16; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; - } - } - } - } -#endif // defined(__x86_64__) || defined(_M_X64) - for (; jj + 3 < max_jj; jj += 4) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - - if (dy0 == dy3) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); - _mm_storel_epi64((__m128i*)pp, _r01); - pp += 8; - } - else if (stride_w == 2) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); - _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); - _mm_storel_epi64((__m128i*)pp, _r01); - pp += 8; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp += 8; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _mm_storeu_si128((__m128i*)pp, _r0); - _mm_storeu_si128((__m128i*)(pp + 16), _r1); - pp += 32; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp += 4; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp += 8; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _mm_storeu_si128((__m128i*)pp, _r0); - _mm_storeu_si128((__m128i*)(pp + 16), _r1); - pp += 32; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; - } - } - } - } -#endif // __SSE2__ - for (; jj + 1 < max_jj; jj += 2) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - - if (dy0 == dy1) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp += 4; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - -#if __SSE2__ - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } -#endif // __SSE2__ - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp += 2; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp += 4; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - -#if __SSE2__ - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } -#endif // __SSE2__ - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; - } - } - } - } - for (; jj < max_jj; jj++) - { - int dy = (j + jj) / outw; - int dx = (j + jj) % outw; - - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x = stride_w * dx + dilation_w * v; - int y = stride_h * dy + dilation_h * u; - - const signed char* sptr = img.row(y) + x * elempack; - -#if __SSE2__ - if (elempack == 8) - { - _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)sptr)); - pp += 8; - } -#endif // __SSE2__ - if (elempack == 1) - { - pp[0] = sptr[0]; - pp += 1; - } - } - } + convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); } static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt) diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index 4bd6a4ef2bf..9dcd9f30a2c 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -40,6 +40,7 @@ namespace ncnn { #include "convolution_3x3_winograd.h" #include "convolution_packed.h" +#include "convolution_im2col_gemm.h" #if NCNN_INT8 #include "convolution_3x3_int8.h" @@ -74,7 +75,6 @@ Convolution_x86::Convolution_x86() activation = 0; nT = 0; convolution_dilation1 = 0; - gemm = 0; } static void convolution_transform_kernel_packed_sse(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, int kernel_w, int kernel_h, int elempack, int out_elempack) @@ -463,85 +463,28 @@ int Convolution_x86::create_pipeline(const Option& opt) if ((opt.use_sgemm_convolution && prefer_sgemm) || (kernel_w == 1 && kernel_h == 1)) { - const int maxk = kernel_w * kernel_h; + convolution_im2col_gemm_transform_kernel(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h, opt); - gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - - ncnn::ParamDict pd; - pd.set(2, 0); // transA - pd.set(3, 0); // transB - pd.set(4, 1); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, num_output); // M = outch - pd.set(8, 0); // N = size - pd.set(9, maxk * num_input); // K = maxk*inch - pd.set(10, bias_term ? 1 : -1); // constant_broadcast_type_C = (M) - pd.set(11, 1); // output_N1M - - gemm->load_param(pd); - - // maxk-inch-outch to pa-maxk-inch/pa-outch - Mat tmp; - { - Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); - - tmp.create(maxk * num_input, num_output); - - for (int q = 0; q < num_output; q += 1) - { - float* g00 = tmp.row(q); - - for (int p = 0; p + (elempack - 1) < num_input; p += elempack) - { - for (int k = 0; k < maxk; k++) - { - for (int i = 0; i < elempack; i++) - { - const float* k00 = weight_data_r2.channel(q).row(p + i); - g00[0] = k00[k]; - g00++; - } - } - } - } - } - - if (bias_term) - { - ncnn::Mat weights[2]; - weights[0] = tmp; - weights[1] = bias_data; - - gemm->load_model(ModelBinFromMatArray(weights)); - } - else - { - ncnn::Mat weights[1]; - weights[0] = tmp; + if (opt.lightmode) + weight_data.release(); - gemm->load_model(ModelBinFromMatArray(weights)); - } + return 0; + } - gemm->create_pipeline(opt); + if ((elempack == 16 && out_elempack == 1 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 8 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 8 && out_elempack == 8 && kernel_w == 2 && kernel_h == 2 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 1 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 1 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + || (elempack == 8 && out_elempack == 1 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 1 && out_elempack == 4 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + || (elempack == 1 && out_elempack == 4 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2)) + { + convolution_transform_kernel_packed_sse(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); } else { - if ((elempack == 16 && out_elempack == 1 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 8 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 8 && out_elempack == 8 && kernel_w == 2 && kernel_h == 2 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 1 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 1 && out_elempack == 8 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) - || (elempack == 8 && out_elempack == 1 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 1 && out_elempack == 4 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - || (elempack == 1 && out_elempack == 4 && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2)) - { - convolution_transform_kernel_packed_sse(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); - } - else - { - convolution_transform_kernel_packed(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h); - } + convolution_transform_kernel_packed(weight_data, weight_data_tm, num_input, num_output, kernel_w, kernel_h); } if (opt.lightmode) @@ -566,13 +509,6 @@ int Convolution_x86::destroy_pipeline(const Option& opt) convolution_dilation1 = 0; } - if (gemm) - { - gemm->destroy_pipeline(opt); - delete gemm; - gemm = 0; - } - return 0; } @@ -746,398 +682,130 @@ int Convolution_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option if ((opt.use_sgemm_convolution && prefer_sgemm) || (kernel_w == 1 && kernel_h == 1)) { - // im2col - Mat bottom_im2col; - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + int _nT = nT ? nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) { - bottom_im2col = bottom_blob_bordered; - bottom_im2col.w = w * h; - bottom_im2col.h = 1; + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } - else if (kernel_w == 1 && kernel_h == 1) - { - const int size = outw * outh; - bottom_im2col.create(size, channels, elemsize, elempack, opt.workspace_allocator); - if (bottom_im2col.empty()) - return -100; + convolution_im2col_gemm(bottom_blob_bordered, top_blob, weight_sgemm_data, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, _nT, opt); - const int gap = (w * stride_h - outw * stride_w) * elempack; + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + return 0; + } #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const float* sptr = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p); - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m512 _val = _mm512_load_ps(sptr); - _mm512_store_ps(ptr, _val); - - sptr += stride_w * 16; - ptr += 16; - } - - sptr += gap; - } - } - } -#endif // __AVX512F__ - - if (elempack == 8) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const float* sptr = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p); - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m256 _val = _mm256_load_ps(sptr); - _mm256_store_ps(ptr, _val); - - sptr += stride_w * 8; - ptr += 8; - } - - sptr += gap; - } - } - } -#endif // __AVX__ - - if (elempack == 4) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const float* sptr = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p); - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m128 _val = _mm_load_ps(sptr); - _mm_store_ps(ptr, _val); - - sptr += stride_w * 4; - ptr += 4; - } - - sptr += gap; - } - } - } -#endif // __SSE2__ - - if (elempack == 1) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const float* sptr = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p); - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - ptr[0] = sptr[0]; - - sptr += stride_w; - ptr += 1; - } - - sptr += gap; - } - } - } - } - else + if (elempack == 16 && out_elempack == 1) + { + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - const int size = outw * outh; - const int maxk = kernel_w * kernel_h; - - bottom_im2col.create(size, maxk * channels, elemsize, elempack, opt.workspace_allocator); - if (bottom_im2col.empty()) - return -100; - - const int gap = (w * stride_h - outw * stride_w) * elempack; + conv3x3s1_pack16to1_avx512(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) + if (activation) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const float* sptr = img.row(dilation_h * u) + dilation_w * v * 16; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m512 _val = _mm512_load_ps(sptr); - _mm512_store_ps(ptr, _val); - - sptr += stride_w * 16; - ptr += 16; - } - - sptr += gap; - } - } - } - } + activation->forward_inplace(top_blob, opt); } + return 0; + } + } #endif // __AVX512F__ - if (elempack == 8) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const float* sptr = img.row(dilation_h * u) + dilation_w * v * 8; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m256 _val = _mm256_load_ps(sptr); - _mm256_store_ps(ptr, _val); - - sptr += stride_w * 8; - ptr += 8; - } - - sptr += gap; - } - } - } - } - } -#endif // __AVX__ - - if (elempack == 4) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const float* sptr = img.row(dilation_h * u) + dilation_w * v * 4; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - __m128 _val = _mm_load_ps(sptr); - _mm_store_ps(ptr, _val); - - sptr += stride_w * 4; - ptr += 4; - } - - sptr += gap; - } - } - } - } - } -#endif // __SSE2__ + if (elempack == 8 && out_elempack == 8) + { + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv3x3s1_pack8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (elempack == 1) + if (activation) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < channels; p++) - { - const Mat img = bottom_blob_bordered.channel(p); - float* ptr = bottom_im2col.row(p * maxk); - - for (int u = 0; u < kernel_h; u++) - { - for (int v = 0; v < kernel_w; v++) - { - const float* sptr = img.row(dilation_h * u) + dilation_w * v; - - for (int i = 0; i < outh; i++) - { - for (int j = 0; j < outw; j++) - { - ptr[0] = sptr[0]; - - sptr += stride_w; - ptr += 1; - } - - sptr += gap; - } - } - } - } + activation->forward_inplace(top_blob, opt); } + return 0; } - - // sgemm + if (kernel_w == 2 && kernel_h == 2 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - top_blob.w = outw * outh; - top_blob.h = 1; - } - Option opt_b = opt; - opt_b.blob_allocator = top_blob.allocator; - gemm->forward(bottom_im2col, top_blob, opt_b); - { - top_blob.w = outw; - top_blob.h = outh; - } + conv2x2s1_pack8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + return 0; } } - else + + if (elempack == 1 && out_elempack == 8) { -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16 && out_elempack == 1) + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_pack16to1_avx512(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); + conv3x3s1_pack1to8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; + if (activation) + { + activation->forward_inplace(top_blob, opt); } + return 0; } -#endif // __AVX512F__ - - if (elempack == 8 && out_elempack == 8) + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_pack8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); + conv3x3s2_pack1to8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; - } - if (kernel_w == 2 && kernel_h == 2 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + if (activation) { - conv2x2s1_pack8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; + activation->forward_inplace(top_blob, opt); } + return 0; } + } - if (elempack == 1 && out_elempack == 8) + if (elempack == 8 && out_elempack == 1) + { + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_pack1to8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); + conv3x3s1_pack8to1_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; - } - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + if (activation) { - conv3x3s2_pack1to8_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; + activation->forward_inplace(top_blob, opt); } + return 0; } + } +#endif // __AVX__ - if (elempack == 8 && out_elempack == 1) + if (elempack == 1 && out_elempack == 4) + { + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_pack8to1_avx(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); + conv3x3s1_pack1to4_sse(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; + if (activation) + { + activation->forward_inplace(top_blob, opt); } + return 0; } -#endif // __AVX__ - - if (elempack == 1 && out_elempack == 4) + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) { - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_pack1to4_sse(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); + conv3x3s2_pack1to4_sse(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; - } - if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + if (activation) { - conv3x3s2_pack1to4_sse(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, opt); - - if (activation) - { - activation->forward_inplace(top_blob, opt); - } - return 0; + activation->forward_inplace(top_blob, opt); } + return 0; } + } #endif // __SSE2__ - convolution_packed(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); - } + convolution_packed(bottom_blob_bordered, top_blob, weight_data_tm, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); return 0; } diff --git a/src/layer/x86/convolution_x86.h b/src/layer/x86/convolution_x86.h index fdfa88f7374..3befb11f0a4 100644 --- a/src/layer/x86/convolution_x86.h +++ b/src/layer/x86/convolution_x86.h @@ -51,8 +51,6 @@ class Convolution_x86 : public Convolution // forwardDilation Layer* convolution_dilation1; - Layer* gemm; - #if NCNN_INT8 Mat scale_in_data; #endif diff --git a/tests/test_convolution_2.cpp b/tests/test_convolution_2.cpp index 5135f5bd780..7243dd64978 100644 --- a/tests/test_convolution_2.cpp +++ b/tests/test_convolution_2.cpp @@ -159,7 +159,8 @@ static int test_convolution_0() || test_convolution(15, 12, 19, 3, 4, 1, 2, 2, 1) || test_convolution(14, 14, 24, 31, 5, 1, 2, 2, 1) || test_convolution(12, 12, 20, 15, 6, 1, 1, 0, 0) - || test_convolution(11, 10, 12, 7, 4, 2, 1, 2, 1); + || test_convolution(11, 10, 12, 7, 4, 2, 1, 2, 1) + || test_convolution(1, 11, 48, 26, 7, 1, 2, 3, 1); } static int test_convolution_1()