Skip to content

Commit

Permalink
x86: runtime-dispatch int8 gemm pack_B tile routines to AVX2
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 11, 2024
1 parent 87f94c0 commit 3710641
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile
// Forward declarations of the AVX2-compiled int8 gemm kernels defined in
// gemm_x86_avx2.cpp. Declared only when this translation unit itself is built
// WITHOUT AVX2/VNNI (!__AVX2__ etc.) but runtime dispatch (NCNN_RUNTIME_CPU)
// to the separately AVX2-compiled copies is available.
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__
void pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk);
void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk);
void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk);
void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk);
void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales);
void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales);
void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale);
void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale);
void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose);
void gemm_transB_packed_tile_int8_avx2(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk);
#endif
Expand Down Expand Up @@ -665,6 +669,14 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in
}
#endif

// Runtime dispatch: this TU is built without AVX2 — if the CPU supports
// AVX2, hand off to the AVX2-compiled copy of this routine and return.
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__
if (ncnn::cpu_support_x86_avx2())
{
pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk);
return;
}
#endif

// NCNN_LOGE("pack_B_tile_int8");
// assert B.elempack == 1
// assert B.dims == 2
Expand Down Expand Up @@ -964,6 +976,14 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj,
}
#endif

// Runtime dispatch: this TU is built without AVX2 — if the CPU supports
// AVX2, hand off to the AVX2-compiled copy of this routine and return.
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__
if (ncnn::cpu_support_x86_avx2())
{
transpose_pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk);
return;
}
#endif

// NCNN_LOGE("transpose_pack_B_tile_int8");
// assert B.elempack == 1
// assert B.dims == 2
Expand Down Expand Up @@ -4668,6 +4688,14 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
}
#endif

// Runtime dispatch: this TU is built without AVX2 — if the CPU supports
// AVX2, hand off to the AVX2-compiled copy of this routine and return.
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__
if (ncnn::cpu_support_x86_avx2())
{
pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale);
return;
}
#endif

const int elempack = B.elempack;
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;

Expand Down Expand Up @@ -5684,6 +5712,14 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
}
#endif

// Runtime dispatch: this TU is built without AVX2 — if the CPU supports
// AVX2, hand off to the AVX2-compiled copy of this routine and return.
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__
if (ncnn::cpu_support_x86_avx2())
{
transpose_pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale);
return;
}
#endif

const int elempack = B.elempack;
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;

Expand Down
20 changes: 20 additions & 0 deletions src/layer/x86/gemm_x86_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, i
transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk);
}

// AVX2-compiled entry point used by runtime dispatch in gemm_int8.h:
// forwards to pack_B_tile_int8, which is compiled with AVX2 in this TU.
void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
{
pack_B_tile_int8(B, BT, j, max_jj, k, max_kk);
}

// AVX2-compiled entry point used by runtime dispatch in gemm_int8.h:
// forwards to transpose_pack_B_tile_int8, which is compiled with AVX2 in this TU.
void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
{
transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk);
}

void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales);
Expand All @@ -48,6 +58,16 @@ void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int m
transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

// AVX2-compiled entry point used by runtime dispatch in gemm_int8.h:
// forwards to pack_B_tile_fp32_to_int8, which is compiled with AVX2 in this TU.
void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

// AVX2-compiled entry point used by runtime dispatch in gemm_int8.h:
// forwards to transpose_pack_B_tile_fp32_to_int8, which is compiled with AVX2 in this TU.
void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose)
{
unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose);
Expand Down

0 comments on commit 3710641

Please sign in to comment.