diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index a9d5dea9040..05e0447545a 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -39,8 +39,12 @@ void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose); void gemm_transB_packed_tile_int8_avx2(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); #endif @@ -665,6 +669,14 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + // NCNN_LOGE("pack_B_tile_int8"); // assert B.elempack == 1 // assert B.dims == 2 @@ -964,6 +976,14 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + // NCNN_LOGE("transpose_pack_B_tile_int8"); // assert B.elempack == 1 // assert B.dims == 2 @@ -4668,6 +4688,14 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; @@ -5684,6 +5712,14 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; diff --git a/src/layer/x86/gemm_x86_avx2.cpp b/src/layer/x86/gemm_x86_avx2.cpp index 050c65934fb..ccc161240c6 100644 --- a/src/layer/x86/gemm_x86_avx2.cpp +++ b/src/layer/x86/gemm_x86_avx2.cpp @@ -38,6 +38,16 @@ void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, i transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); } +void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) { pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); @@ -48,6 +58,16 @@ void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int m transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); } +void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) { unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose);