From 1c7af00499380dfe760e6131bea63db210bdab2e Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 12 Oct 2024 19:02:55 +0800 Subject: [PATCH 01/15] gemm int8 quantization (#5706) * quantize gemm * write gemm quantize scales * update doc * less openmp args * x86 riscv fallback * skip gemm vulkan int8 * fix noint8 test, fix arm bf16 test * enable vfpv4 on neon build only * fix gemm vulkan without C * fp16 pack8 output * enable elempack=8 only for asimdhp+ * tiled gemm int8 test * opt arm64 tiles, fix asimdhp dispatch --- CMakeLists.txt | 28 +- cmake/ncnn_add_layer.cmake | 96 +- docs/developer-guide/operators.md | 7 +- src/layer/arm/gemm_arm.cpp | 1087 ++ src/layer/arm/gemm_arm.h | 6 + src/layer/arm/gemm_arm_asimddp.cpp | 145 + src/layer/arm/gemm_arm_asimdhp.cpp | 21 + src/layer/arm/gemm_arm_i8mm.cpp | 115 + src/layer/arm/gemm_arm_vfpv4.cpp | 51 + src/layer/arm/gemm_int8.h | 14687 +++++++++++++++++++++ src/layer/arm/gemm_int8_bf16s.h | 8566 ++++++++++++++++ src/layer/arm/gemm_int8_fp16s.h | 10368 +++++++++++++++++++ src/layer/gemm.cpp | 451 +- src/layer/gemm.h | 12 + src/layer/riscv/gemm_riscv.cpp | 15 + src/layer/vulkan/gemm_vulkan.cpp | 157 +- src/layer/vulkan/gemm_vulkan.h | 2 + src/layer/x86/gemm_x86.cpp | 15 + tests/test_gemm.cpp | 2 +- tests/test_gemm_1.cpp | 4 +- tests/test_gemm_3.cpp | 336 + tests/test_gemm_4.cpp | 140 + tools/modelwriter.h | 18 + tools/quantize/ncnn2int8.cpp | 109 + 24 files changed, 36265 insertions(+), 173 deletions(-) create mode 100644 src/layer/arm/gemm_arm_asimddp.cpp create mode 100644 src/layer/arm/gemm_arm_i8mm.cpp create mode 100644 src/layer/arm/gemm_int8.h create mode 100644 src/layer/arm/gemm_int8_bf16s.h create mode 100644 src/layer/arm/gemm_int8_fp16s.h create mode 100644 tests/test_gemm_3.cpp create mode 100644 tests/test_gemm_4.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f32a80c86ee..875a8d06598f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -162,21 +162,25 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm") endif() if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32) - if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4") - check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON) - unset(CMAKE_REQUIRED_FLAGS) - else() - set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4") - check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + if(NCNN_COMPILER_SUPPORT_ARM_NEON) + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) + set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4") + check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) - if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4) - set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee") - check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) - endif() + unset(CMAKE_REQUIRED_FLAGS) + else() +
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4") + check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) - unset(CMAKE_REQUIRED_FLAGS) + if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4) + set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee") + check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) + endif() + + unset(CMAKE_REQUIRED_FLAGS) + endif() endif() if(NCNN_COMPILER_SUPPORT_ARM_VFPV4 OR NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 6ce5feadbf31..7f334fb0b68d 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -144,25 +144,25 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") @@ -175,25 +175,25 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND 
NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "/arch:AVX -mxop /D__SSSE3__ /D__SSE4_1__ /D__XOP__") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() else() @@ -206,25 +206,25 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "-mavx") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "-mavx -mxop") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "-mavx -mf16c") endif() endif() @@ -254,28 +254,28 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8") endif() # TODO add support for sve family - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) endif() elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") if(NCNN_VFPV4) @@ -284,28 +284,28 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 -march=armv8.2-a+fp16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND 
NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 -march=armv8.2-a+fp16+dotprod /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 -march=armv8.2-a+fp16+fp16fml /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+bf16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+i8mm /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8") endif() # TODO add support for sve family - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) endif() else() if(NCNN_VFPV4) @@ -314,31 +314,31 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "-march=armv8.2-a+fp16") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "-march=armv8.2-a+fp16+dotprod") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm") endif() - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve") endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2") endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16") endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm") endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm") endif() endif() diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index de4d6b428e99..28f1ce626466 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -942,15 +942,18 @@ y = (gemm(a, b) + c * beta) * alpha | 12 | output_elempack | int | 0 | | | 13 | output_elemtype | int | 0 | | | 14 | output_transpose | int| 0 | | +| 18 | int8_scale_term | int | 0 | | | 20 | constant_TILE_M | int | 0 | | | 21 | constant_TILE_N | int | 0 | | | 22 | constant_TILE_K | int | 0 | | | weight | type | shape | | ------------- | ----- | --------------------- | -| A_data | float | [M, K] or [K, 
M] | -| B_data | float | [N, K] or [K, N] | +| A_data | float/fp16/int8 | [M, K] or [K, M] | +| B_data | float/fp16/int8 | [N, K] or [K, N] | | C_data | float | [1], [M] or [N] or [1, M] or [N,1] or [N, M] | +| A_data_int8_scales| float | [M] | +| B_data_int8_scales| float | [1] | # GridSample ``` diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index e798680e2afa..7607d8f523e5 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -16,6 +16,7 @@ #if __ARM_NEON #include +#include "neon_mathfun.h" #endif // __ARM_NEON #include "arm_usability.h" @@ -29,6 +30,13 @@ namespace ncnn { #include "gemm_bf16s.h" #endif +#if NCNN_INT8 +#include "gemm_int8.h" +#if NCNN_BF16 +#include "gemm_int8_bf16s.h" +#endif +#endif + Gemm_arm::Gemm_arm() { #if __ARM_NEON @@ -2461,6 +2469,79 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int kk = 0; for (; kk < max_kk; kk += 1) { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v2.4s}, [%0], #16 \n" + "prfm pldl1keep, [%1, #256] \n" + "ld1 {v0.4s, v1.4s}, [%1], #32 \n" + "fmla %2.4s, v2.4s, v0.s[0] \n" + "fmla %3.4s, v2.4s, v0.s[1] \n" + "fmla %4.4s, v2.4s, v0.s[2] \n" + "fmla %5.4s, v2.4s, v0.s[3] \n" + "fmla %6.4s, v2.4s, v1.s[0] \n" + "fmla %7.4s, v2.4s, v1.s[1] \n" + "fmla %8.4s, v2.4s, v1.s[2] \n" + "fmla %9.4s, v2.4s, v1.s[3] \n" + : "=r"(pA), + "=r"(pB), + "=w"(_sum0), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6), + "=w"(_sum7) + : "0"(pA), + "1"(pB), + "2"(_sum0), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "6"(_sum4), + "7"(_sum5), + "8"(_sum6), + "9"(_sum7) + : "memory", "v0", "v1", "v2", "v3"); +#else + asm volatile( + "pld [%0, #128] \n" + "vld1.f32 {d4-d5}, [%0]! \n" + "pld [%1, #256] \n" + "vld1.f32 {d0-d3}, [%1]! \n" + "vmla.f32 %q2, q2, d0[0] \n" + "vmla.f32 %q3, q2, d0[1] \n" + "vmla.f32 %q4, q2, d1[0] \n" + "vmla.f32 %q5, q2, d1[1] \n" + "vmla.f32 %q6, q2, d2[0] \n" + "vmla.f32 %q7, q2, d2[1] \n" + "vmla.f32 %q8, q2, d3[0] \n" + "vmla.f32 %q9, q2, d3[1] \n" + : "=r"(pA), + "=r"(pB), + "=w"(_sum0), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6), + "=w"(_sum7) + : "0"(pA), + "1"(pB), + "2"(_sum0), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "6"(_sum4), + "7"(_sum5), + "8"(_sum6), + "9"(_sum7) + : "memory", "q0", "q1", "q2"); +#endif +#else // NCNN_GNU_INLINE_ASM float32x4_t _pA = vld1q_f32(pA); float32x4_t _pB0 = vld1q_f32(pB); float32x4_t _pB1 = vld1q_f32(pB + 4); @@ -2487,6 +2568,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons pA += 4; pB += 8; +#endif // NCNN_GNU_INLINE_ASM } if (k_end) @@ -4164,6 +4246,13 @@ static int gemm_AT_BT_arm(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_b int Gemm_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + #if NCNN_ARM82 if (cpu_support_arm_asimdhp() && opt.use_fp16_storage) { @@ -4311,6 +4400,14 @@ int Gemm_arm::create_pipeline(const Option& opt) int Gemm_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + // return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = constantA ? 
AT_data : bottom_blobs[0]; int elembits = bottom_blob.elembits(); @@ -5199,4 +5296,994 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + } + } + + int out_elempack = 1; +#if __ARM_NEON + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % 4 == 0 ? 4 : 1; +#if NCNN_ARM82 + if (cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic) + { + // TODO use output_elemtype + out_elempack = outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; + } +#endif + } +#endif // __ARM_NEON + + // FIXME use output_elempack + // int output_elempack = out_elempack > 4 ? 4 : out_elempack; + + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 4u * out_elempack; + + // FIXME use output_elemtype instead of input_elemtype + int output_elemtype = input_elemtype; + + // TODO use output_elemtype + if (opt.use_bf16_storage) + { + out_elemsize = 2u * out_elempack; + } +#if NCNN_VFPV4 + else if (support_fp16_storage && opt.use_fp16_storage) + { + out_elemsize = 2u * out_elempack; + } +#endif + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_arm_int8(AT_data, A_data_int8_scales, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, beta, input_elemtype, output_elemtype, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_arm_int8(AT_data, A_data_int8_scales, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, beta, input_elemtype, output_elemtype, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_arm_int8(A, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, beta, input_elemtype, output_elemtype, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_arm_int8(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, beta, input_elemtype, output_elemtype, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + + return ret; +} +#endif + } // namespace ncnn diff --git a/src/layer/arm/gemm_arm.h b/src/layer/arm/gemm_arm.h index 0c1eab108baf..4b7d8ab24258 100644 --- a/src/layer/arm/gemm_arm.h +++ b/src/layer/arm/gemm_arm.h @@ -41,12 +41,18 @@ class Gemm_arm : public Gemm int create_pipeline_bf16s(const Option& opt); int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: int nT; Mat AT_data; Mat BT_data; Mat CT_data; + + int input_elemtype; // 0=auto 1=fp32 2=fp16 3=bf16 }; } // namespace ncnn diff --git a/src/layer/arm/gemm_arm_asimddp.cpp b/src/layer/arm/gemm_arm_asimddp.cpp new file mode 100644 index 000000000000..de4689988148 --- /dev/null +++ b/src/layer/arm/gemm_arm_asimddp.cpp @@ -0,0 +1,145 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" +#include "gemm_int8_fp16s.h" + +#if NCNN_BF16 +#include "gemm_int8_bf16s.h" +#endif + +void pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} + +void transpose_unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + transpose_unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} + +void gemm_transB_packed_tile_int8_asimddp(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + 
unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} + +void transpose_unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + transpose_unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} + +#if NCNN_BF16 +void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} + +void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ + transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/arm/gemm_arm_asimdhp.cpp b/src/layer/arm/gemm_arm_asimdhp.cpp index cb0aa87e4add..dd5d4e6f8460 100644 --- a/src/layer/arm/gemm_arm_asimdhp.cpp +++ b/src/layer/arm/gemm_arm_asimdhp.cpp @@ -27,6 +27,10 @@ namespace ncnn { #include "gemm_bf16s_fp16s.h" #include "gemm_fp16s.h" +#if NCNN_INT8 +#include "gemm_int8_fp16s.h" +#endif + static void gemm_transB_packed_tile_fp16sa(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) { const int out_elempack = top_blob.elempack; @@ -3026,4 +3030,21 @@ int Gemm_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; + const signed char* p4 = A.row(i + ii + 4) + k; + const signed char* p5 = A.row(i + ii + 5) + k; + const signed char* p6 = A.row(i + ii + 6) + k; + const signed char* p7 = A.row(i + ii + 7) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); + int8x16_t _p4 = vld1q_s8(p4); + int8x16_t _p5 = vld1q_s8(p5); + int8x16_t _p6 = vld1q_s8(p6); + int8x16_t _p7 = vld1q_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(vget_low_s8(_p0), vget_low_s8(_p1)); + int8x16_t _r1 = 
vcombine_s8(vget_low_s8(_p2), vget_low_s8(_p3)); + int8x16_t _r2 = vcombine_s8(vget_low_s8(_p4), vget_low_s8(_p5)); + int8x16_t _r3 = vcombine_s8(vget_low_s8(_p6), vget_low_s8(_p7)); + int8x16_t _r4 = vcombine_s8(vget_high_s8(_p0), vget_high_s8(_p1)); + int8x16_t _r5 = vcombine_s8(vget_high_s8(_p2), vget_high_s8(_p3)); + int8x16_t _r6 = vcombine_s8(vget_high_s8(_p4), vget_high_s8(_p5)); + int8x16_t _r7 = vcombine_s8(vget_high_s8(_p6), vget_high_s8(_p7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _p01 = vzipq_s32(vreinterpretq_s32_s8(_p0), vreinterpretq_s32_s8(_p1)); + int32x4x2_t _p23 = vzipq_s32(vreinterpretq_s32_s8(_p2), vreinterpretq_s32_s8(_p3)); + int32x4x2_t _p45 = vzipq_s32(vreinterpretq_s32_s8(_p4), vreinterpretq_s32_s8(_p5)); + int32x4x2_t _p67 = vzipq_s32(vreinterpretq_s32_s8(_p6), vreinterpretq_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[0]), vget_low_s32(_p23.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[0]), vget_low_s32(_p67.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[0]), vget_high_s32(_p23.val[0]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[0]), vget_high_s32(_p67.val[0]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[1]), vget_low_s32(_p23.val[1]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[1]), vget_low_s32(_p67.val[1]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[1]), vget_high_s32(_p23.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[1]), vget_high_s32(_p67.val[1]))); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _p01 = vzipq_s16(vreinterpretq_s16_s8(_p0), vreinterpretq_s16_s8(_p1)); + int16x8x2_t _p23 = vzipq_s16(vreinterpretq_s16_s8(_p2), vreinterpretq_s16_s8(_p3)); + int16x8x2_t _p45 = vzipq_s16(vreinterpretq_s16_s8(_p4), vreinterpretq_s16_s8(_p5)); + int16x8x2_t _p67 = vzipq_s16(vreinterpretq_s16_s8(_p6), vreinterpretq_s16_s8(_p7)); + int32x4x2_t _t0 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[0]), vreinterpretq_s32_s16(_p23.val[0])); + int32x4x2_t _t1 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[1]), vreinterpretq_s32_s16(_p23.val[1])); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[0]), vreinterpretq_s32_s16(_p67.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[1]), vreinterpretq_s32_s16(_p67.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t2.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t2.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t2.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t2.val[1]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + vst1q_s8(pp + 64, _r4); + 
vst1q_s8(pp + 80, _r5); + vst1q_s8(pp + 96, _r6); + vst1q_s8(pp + 112, _r7); + pp += 128; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + p4 += 16; + p5 += 16; + p6 += 16; + p7 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); + int8x8_t _p4 = vld1_s8(p4); + int8x8_t _p5 = vld1_s8(p5); + int8x8_t _p6 = vld1_s8(p6); + int8x8_t _p7 = vld1_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(_p0, _p1); + int8x16_t _r1 = vcombine_s8(_p2, _p3); + int8x16_t _r2 = vcombine_s8(_p4, _p5); + int8x16_t _r3 = vcombine_s8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); + int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); + int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); + int32x2x2_t _p67 = vzip_s32(vreinterpret_s32_s8(_p6), vreinterpret_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[0], _p23.val[0])); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[0], _p67.val[0])); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); + int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); + int16x8_t _p26 = vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); + int16x8_t _p37 = vreinterpretq_s16_s8(vcombine_s8(_p3, _p7)); + int16x8x2_t _t0 = vzipq_s16(_p04, _p15); + int16x8x2_t _t1 = vzipq_s16(_p26, _p37); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[0]), vreinterpretq_s32_s16(_t1.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[1]), vreinterpretq_s32_s16(_t1.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + pp[16] = p4[0]; + pp[17] = p4[1]; + pp[18] = p4[2]; + pp[19] = p4[3]; + pp[20] = p5[0]; + pp[21] = p5[1]; + pp[22] = p5[2]; + pp[23] = p5[3]; + pp[24] = p6[0]; + pp[25] = p6[1]; + pp[26] = p6[2]; + pp[27] = p6[3]; + pp[28] = p7[0]; + pp[29] = p7[1]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = 
p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp[16] = p0[2]; + pp[17] = p0[3]; + pp[18] = p1[2]; + pp[19] = p1[3]; + pp[20] = p2[2]; + pp[21] = p2[3]; + pp[22] = p3[2]; + pp[23] = p3[3]; + pp[24] = p4[2]; + pp[25] = p4[3]; + pp[26] = p5[2]; + pp[27] = p5[3]; + pp[28] = p6[2]; + pp[29] = p6[3]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp += 16; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x4_t _r0123; + _r0123.val[0] = vreinterpretq_s64_s8(_p0); + _r0123.val[1] = vreinterpretq_s64_s8(_p1); + _r0123.val[2] = vreinterpretq_s64_s8(_p2); + _r0123.val[3] = vreinterpretq_s64_s8(_p3); + vst4q_s64((int64_t*)pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x4_t _r0123; + _r0123.val[0] = vreinterpretq_s32_s8(_p0); + _r0123.val[1] = vreinterpretq_s32_s8(_p1); + _r0123.val[2] = vreinterpretq_s32_s8(_p2); + _r0123.val[3] = vreinterpretq_s32_s8(_p3); + vst4q_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x4_t _r0123; + _r0123.val[0] = vreinterpretq_s16_s8(_p0); + _r0123.val[1] = vreinterpretq_s16_s8(_p1); + _r0123.val[2] = vreinterpretq_s16_s8(_p2); + _r0123.val[3] = vreinterpretq_s16_s8(_p3); + vst4q_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 64; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); + vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x4_t _r0123; + _r0123.val[0] = vreinterpret_s32_s8(_p0); + _r0123.val[1] = vreinterpret_s32_s8(_p1); + _r0123.val[2] = vreinterpret_s32_s8(_p2); + _r0123.val[3] = vreinterpret_s32_s8(_p3); + vst4_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x4_t _r0123; + _r0123.val[0] = vreinterpret_s16_s8(_p0); + _r0123.val[1] = vreinterpret_s16_s8(_p1); + _r0123.val[2] = vreinterpret_s16_s8(_p2); + _r0123.val[3] = vreinterpret_s16_s8(_p3); + vst4_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = 
p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p0[2]; + pp[9] = p0[3]; + pp[10] = p1[2]; + pp[11] = p1[3]; + pp[12] = p2[2]; + pp[13] = p2[3]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x2_t _r01; + _r01.val[0] = vreinterpretq_s64_s8(_p0); + _r01.val[1] = vreinterpretq_s64_s8(_p1); + vst2q_s64((int64_t*)pp, _r01); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _r01; + _r01.val[0] = vreinterpretq_s32_s8(_p0); + _r01.val[1] = vreinterpretq_s32_s8(_p1); + vst2q_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _r01; + _r01.val[0] = vreinterpretq_s16_s8(_p0); + _r01.val[1] = vreinterpretq_s16_s8(_p1); + vst2q_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 16; + p1 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _r01; + _r01.val[0] = vreinterpret_s32_s8(_p0); + _r01.val[1] = vreinterpret_s32_s8(_p1); + vst2_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _r01; + _r01.val[0] = vreinterpret_s16_s8(_p0); + _r01.val[1] = vreinterpret_s16_s8(_p1); + vst2_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 8; + p1 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 8; + p0 += 4; + p1 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(i + ii) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + vst1q_s8(pp, vld1q_s8(p0)); + pp += 16; + p0 += 16; + } + 
for (; kk + 7 < max_kk; kk += 8) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_int8_i8mm(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_int8_asimddp(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_A_tile_int8"); + // assert A.elempack == 1 + // assert A.dims == 2 + + const int A_hstep = A.w; + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _r0 = vld1_s8(p0); + int8x8_t _r1 = vld1_s8(p0 + A_hstep); + int8x8_t _r2 = vld1_s8(p0 + A_hstep * 2); + int8x8_t _r3 = vld1_s8(p0 + A_hstep * 3); + int8x8_t _r4 = vld1_s8(p0 + A_hstep * 4); + int8x8_t _r5 = vld1_s8(p0 + A_hstep * 5); + int8x8_t _r6 = vld1_s8(p0 + A_hstep * 6); + int8x8_t _r7 = vld1_s8(p0 + A_hstep * 7); + // transpose8x8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int8x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); + pp += 64; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8x4_t _r0123; + _r0123.val[0] = vld1_s8(p0); + _r0123.val[1] = vld1_s8(p0 + A_hstep); + _r0123.val[2] = vld1_s8(p0 + A_hstep * 2); + _r0123.val[3] = vld1_s8(p0 + A_hstep * 3); + vst4_s8(pp, _r0123); + pp += 32; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + int8x8x2_t _r01; + _r01.val[0] = vld1_s8(p0); + _r01.val[1] = vld1_s8(p0 + A_hstep); + vst2_s8(pp, _r01); + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += A_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[A_hstep * 4]; + pp[5] = p0[A_hstep * 5]; + pp[6] = p0[A_hstep * 6]; + pp[7] = p0[A_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[A_hstep + 1]; + pp[10] = p0[A_hstep * 2 + 1]; + pp[11] = p0[A_hstep * 3 + 1]; + pp[12] = p0[A_hstep * 4 + 1]; + pp[13] = p0[A_hstep * 5 + 1]; + pp[14] = p0[A_hstep * 6 + 1]; + pp[15] = p0[A_hstep * 7 + 1]; + pp[16] = p0[2]; + pp[17] = p0[A_hstep + 2]; + pp[18] = p0[A_hstep * 2 + 2]; + pp[19] = p0[A_hstep * 3 + 2]; + pp[20] = p0[A_hstep * 4 + 2]; + pp[21] = p0[A_hstep * 5 + 2]; + pp[22] = p0[A_hstep * 6 + 2]; + pp[23] = p0[A_hstep * 7 + 2]; + pp[24] = p0[3]; + 
pp[25] = p0[A_hstep + 3]; + pp[26] = p0[A_hstep * 2 + 3]; + pp[27] = p0[A_hstep * 3 + 3]; + pp[28] = p0[A_hstep * 4 + 3]; + pp[29] = p0[A_hstep * 5 + 3]; + pp[30] = p0[A_hstep * 6 + 3]; + pp[31] = p0[A_hstep * 7 + 3]; + pp += 32; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[A_hstep + 2]; + pp[10] = p0[A_hstep * 2 + 2]; + pp[11] = p0[A_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[A_hstep + 3]; + pp[14] = p0[A_hstep * 2 + 3]; + pp[15] = p0[A_hstep * 3 + 3]; + pp += 16; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[A_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[A_hstep + 3]; + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += A_hstep; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[A_hstep * 4]; + pp[5] = p0[A_hstep * 5]; + pp[6] = p0[A_hstep * 6]; + pp[7] = p0[A_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[A_hstep + 1]; + pp[10] = p0[A_hstep * 2 + 1]; + pp[11] = p0[A_hstep * 3 + 1]; + pp[12] = p0[A_hstep * 4 + 1]; + pp[13] = p0[A_hstep * 5 + 1]; + pp[14] = p0[A_hstep * 6 + 1]; + pp[15] = p0[A_hstep * 7 + 1]; + pp += 16; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp += 8; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = p0[0]; + // pp[1] = p0[A_hstep]; + // pp += 2; + // p0 += A_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } +} + +static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_int8_i8mm(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_int8_asimddp(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + signed char* pp = BT; + + int jj = 0; +#if __ARM_NEON +#if 
__aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + const signed char* p4 = B.row(j + jj + 4) + k; + const signed char* p5 = B.row(j + jj + 5) + k; + const signed char* p6 = B.row(j + jj + 6) + k; + const signed char* p7 = B.row(j + jj + 7) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); + int8x16_t _p4 = vld1q_s8(p4); + int8x16_t _p5 = vld1q_s8(p5); + int8x16_t _p6 = vld1q_s8(p6); + int8x16_t _p7 = vld1q_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(vget_low_s8(_p0), vget_low_s8(_p1)); + int8x16_t _r1 = vcombine_s8(vget_low_s8(_p2), vget_low_s8(_p3)); + int8x16_t _r2 = vcombine_s8(vget_low_s8(_p4), vget_low_s8(_p5)); + int8x16_t _r3 = vcombine_s8(vget_low_s8(_p6), vget_low_s8(_p7)); + int8x16_t _r4 = vcombine_s8(vget_high_s8(_p0), vget_high_s8(_p1)); + int8x16_t _r5 = vcombine_s8(vget_high_s8(_p2), vget_high_s8(_p3)); + int8x16_t _r6 = vcombine_s8(vget_high_s8(_p4), vget_high_s8(_p5)); + int8x16_t _r7 = vcombine_s8(vget_high_s8(_p6), vget_high_s8(_p7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _p01 = vzipq_s32(vreinterpretq_s32_s8(_p0), vreinterpretq_s32_s8(_p1)); + int32x4x2_t _p23 = vzipq_s32(vreinterpretq_s32_s8(_p2), vreinterpretq_s32_s8(_p3)); + int32x4x2_t _p45 = vzipq_s32(vreinterpretq_s32_s8(_p4), vreinterpretq_s32_s8(_p5)); + int32x4x2_t _p67 = vzipq_s32(vreinterpretq_s32_s8(_p6), vreinterpretq_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[0]), vget_low_s32(_p23.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[0]), vget_low_s32(_p67.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[0]), vget_high_s32(_p23.val[0]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[0]), vget_high_s32(_p67.val[0]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[1]), vget_low_s32(_p23.val[1]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[1]), vget_low_s32(_p67.val[1]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[1]), vget_high_s32(_p23.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[1]), vget_high_s32(_p67.val[1]))); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _p01 = vzipq_s16(vreinterpretq_s16_s8(_p0), vreinterpretq_s16_s8(_p1)); + int16x8x2_t _p23 = vzipq_s16(vreinterpretq_s16_s8(_p2), vreinterpretq_s16_s8(_p3)); + int16x8x2_t _p45 = vzipq_s16(vreinterpretq_s16_s8(_p4), vreinterpretq_s16_s8(_p5)); + int16x8x2_t _p67 = vzipq_s16(vreinterpretq_s16_s8(_p6), vreinterpretq_s16_s8(_p7)); + int32x4x2_t _t0 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[0]), vreinterpretq_s32_s16(_p23.val[0])); + int32x4x2_t _t1 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[1]), vreinterpretq_s32_s16(_p23.val[1])); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[0]), vreinterpretq_s32_s16(_p67.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[1]), vreinterpretq_s32_s16(_p67.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t2.val[0]))); + int8x16_t _r1 = 
vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t2.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t2.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t2.val[1]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + vst1q_s8(pp + 64, _r4); + vst1q_s8(pp + 80, _r5); + vst1q_s8(pp + 96, _r6); + vst1q_s8(pp + 112, _r7); + pp += 128; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + p4 += 16; + p5 += 16; + p6 += 16; + p7 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); + int8x8_t _p4 = vld1_s8(p4); + int8x8_t _p5 = vld1_s8(p5); + int8x8_t _p6 = vld1_s8(p6); + int8x8_t _p7 = vld1_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(_p0, _p1); + int8x16_t _r1 = vcombine_s8(_p2, _p3); + int8x16_t _r2 = vcombine_s8(_p4, _p5); + int8x16_t _r3 = vcombine_s8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); + int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); + int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); + int32x2x2_t _p67 = vzip_s32(vreinterpret_s32_s8(_p6), vreinterpret_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[0], _p23.val[0])); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[0], _p67.val[0])); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); + int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); + int16x8_t _p26 = vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); + int16x8_t _p37 = vreinterpretq_s16_s8(vcombine_s8(_p3, _p7)); + int16x8x2_t _t0 = vzipq_s16(_p04, _p15); + int16x8x2_t _t1 = vzipq_s16(_p26, _p37); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[0]), vreinterpretq_s32_s16(_t1.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[1]), vreinterpretq_s32_s16(_t1.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for 
(; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + pp[16] = p4[0]; + pp[17] = p4[1]; + pp[18] = p4[2]; + pp[19] = p4[3]; + pp[20] = p5[0]; + pp[21] = p5[1]; + pp[22] = p5[2]; + pp[23] = p5[3]; + pp[24] = p6[0]; + pp[25] = p6[1]; + pp[26] = p6[2]; + pp[27] = p6[3]; + pp[28] = p7[0]; + pp[29] = p7[1]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp[16] = p0[2]; + pp[17] = p0[3]; + pp[18] = p1[2]; + pp[19] = p1[3]; + pp[20] = p2[2]; + pp[21] = p2[3]; + pp[22] = p3[2]; + pp[23] = p3[3]; + pp[24] = p4[2]; + pp[25] = p4[3]; + pp[26] = p5[2]; + pp[27] = p5[3]; + pp[28] = p6[2]; + pp[29] = p6[3]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp += 16; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x4_t _r0123; + _r0123.val[0] = vreinterpretq_s64_s8(_p0); + _r0123.val[1] = vreinterpretq_s64_s8(_p1); + _r0123.val[2] = vreinterpretq_s64_s8(_p2); + _r0123.val[3] = vreinterpretq_s64_s8(_p3); + vst4q_s64((int64_t*)pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x4_t _r0123; + _r0123.val[0] = vreinterpretq_s32_s8(_p0); + _r0123.val[1] = vreinterpretq_s32_s8(_p1); + _r0123.val[2] = vreinterpretq_s32_s8(_p2); + _r0123.val[3] = vreinterpretq_s32_s8(_p3); + vst4q_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x4_t _r0123; + _r0123.val[0] = vreinterpretq_s16_s8(_p0); + _r0123.val[1] = vreinterpretq_s16_s8(_p1); + _r0123.val[2] = vreinterpretq_s16_s8(_p2); + _r0123.val[3] = vreinterpretq_s16_s8(_p3); + vst4q_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 64; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = 
vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); + vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x4_t _r0123; + _r0123.val[0] = vreinterpret_s32_s8(_p0); + _r0123.val[1] = vreinterpret_s32_s8(_p1); + _r0123.val[2] = vreinterpret_s32_s8(_p2); + _r0123.val[3] = vreinterpret_s32_s8(_p3); + vst4_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x4_t _r0123; + _r0123.val[0] = vreinterpret_s16_s8(_p0); + _r0123.val[1] = vreinterpret_s16_s8(_p1); + _r0123.val[2] = vreinterpret_s16_s8(_p2); + _r0123.val[3] = vreinterpret_s16_s8(_p3); + vst4_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p0[2]; + pp[9] = p0[3]; + pp[10] = p1[2]; + pp[11] = p1[3]; + pp[12] = p2[2]; + pp[13] = p2[3]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x2_t _r01; + _r01.val[0] = vreinterpretq_s64_s8(_p0); + _r01.val[1] = vreinterpretq_s64_s8(_p1); + vst2q_s64((int64_t*)pp, _r01); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _r01; + _r01.val[0] = vreinterpretq_s32_s8(_p0); + _r01.val[1] = vreinterpretq_s32_s8(_p1); + vst2q_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _r01; + _r01.val[0] = vreinterpretq_s16_s8(_p0); + _r01.val[1] = vreinterpretq_s16_s8(_p1); + vst2q_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 16; + p1 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _r01; + _r01.val[0] = vreinterpret_s32_s8(_p0); + _r01.val[1] = vreinterpret_s32_s8(_p1); + vst2_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _r01; + _r01.val[0] = vreinterpret_s16_s8(_p0); + _r01.val[1] = vreinterpret_s16_s8(_p1); + vst2_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 8; + p1 += 8; + } + for 
(; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 8; + p0 += 4; + p1 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = B.row(j + jj) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + vst1q_s8(pp, vld1q_s8(p0)); + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_int8_i8mm(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_int8_asimddp(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + const int B_hstep = B.w; + + signed char* pp = BT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _r0 = vld1_s8(p0); + int8x8_t _r1 = vld1_s8(p0 + B_hstep); + int8x8_t _r2 = vld1_s8(p0 + B_hstep * 2); + int8x8_t _r3 = vld1_s8(p0 + B_hstep * 3); + int8x8_t _r4 = vld1_s8(p0 + B_hstep * 4); + int8x8_t _r5 = vld1_s8(p0 + B_hstep * 5); + int8x8_t _r6 = vld1_s8(p0 + B_hstep * 6); + int8x8_t _r7 = vld1_s8(p0 + B_hstep * 7); + // transpose8x8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int8x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); + pp += 64; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8x4_t _r0123; + _r0123.val[0] = vld1_s8(p0); + _r0123.val[1] = vld1_s8(p0 + B_hstep); + _r0123.val[2] = vld1_s8(p0 + B_hstep * 2); + _r0123.val[3] = vld1_s8(p0 + B_hstep * 3); + vst4_s8(pp, _r0123); + pp += 32; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + int8x8x2_t _r01; + _r01.val[0] = vld1_s8(p0); + _r01.val[1] = vld1_s8(p0 + B_hstep); + vst2_s8(pp, _r01); + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + vst1_s8(pp, vld1_s8(p0)); + 
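+            // leftover k rows: copy eight consecutive columns of one B row as-is, no reordering is needed for a single k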
pp += 8; + p0 += B_hstep; + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[B_hstep * 4]; + pp[5] = p0[B_hstep * 5]; + pp[6] = p0[B_hstep * 6]; + pp[7] = p0[B_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[B_hstep + 1]; + pp[10] = p0[B_hstep * 2 + 1]; + pp[11] = p0[B_hstep * 3 + 1]; + pp[12] = p0[B_hstep * 4 + 1]; + pp[13] = p0[B_hstep * 5 + 1]; + pp[14] = p0[B_hstep * 6 + 1]; + pp[15] = p0[B_hstep * 7 + 1]; + pp[16] = p0[2]; + pp[17] = p0[B_hstep + 2]; + pp[18] = p0[B_hstep * 2 + 2]; + pp[19] = p0[B_hstep * 3 + 2]; + pp[20] = p0[B_hstep * 4 + 2]; + pp[21] = p0[B_hstep * 5 + 2]; + pp[22] = p0[B_hstep * 6 + 2]; + pp[23] = p0[B_hstep * 7 + 2]; + pp[24] = p0[3]; + pp[25] = p0[B_hstep + 3]; + pp[26] = p0[B_hstep * 2 + 3]; + pp[27] = p0[B_hstep * 3 + 3]; + pp[28] = p0[B_hstep * 4 + 3]; + pp[29] = p0[B_hstep * 5 + 3]; + pp[30] = p0[B_hstep * 6 + 3]; + pp[31] = p0[B_hstep * 7 + 3]; + pp += 32; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[B_hstep + 1]; + pp[6] = p0[B_hstep * 2 + 1]; + pp[7] = p0[B_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[B_hstep + 2]; + pp[10] = p0[B_hstep * 2 + 2]; + pp[11] = p0[B_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[B_hstep + 3]; + pp[14] = p0[B_hstep * 2 + 3]; + pp[15] = p0[B_hstep * 3 + 3]; + pp += 16; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[B_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[B_hstep + 3]; + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += B_hstep; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[B_hstep * 4]; + pp[5] = p0[B_hstep * 5]; + pp[6] = p0[B_hstep * 6]; + pp[7] = p0[B_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[B_hstep + 1]; + pp[10] = p0[B_hstep * 2 + 1]; + pp[11] = p0[B_hstep * 3 + 1]; + pp[12] = p0[B_hstep * 4 + 1]; + pp[13] = p0[B_hstep * 5 + 1]; + pp[14] = p0[B_hstep * 6 + 1]; + pp[15] = p0[B_hstep * 7 + 1]; + pp += 16; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[B_hstep + 1]; + pp[6] = p0[B_hstep * 2 + 1]; + pp[7] = p0[B_hstep * 3 + 1]; + pp += 8; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = 
B.row(k) + (j + jj); + + int kk = 0; + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = p0[0]; + // pp[1] = p0[B_hstep]; + // pp += 2; + // p0 += B_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } +} + +static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.w; + + // NCNN_LOGE("compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif + + for (int ii = 0; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); + + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + +#endif + ps += 4; + pods += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int ii = 0; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (; kk + 15 < K; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 
12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 7 < K; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(p0[0])); + p0++; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_fp32_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_fp32_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + float32x4x4_t _r = vld4q_f32(p0 + A_hstep * 4); + float32x4x4_t _s = vld4q_f32(p0 + A_hstep * 4 + 16); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale0, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale0, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale0, 2); + float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale0, 3); + float32x4_t _p8 = vmulq_laneq_f32(_r.val[0], _scale1, 0); + float32x4_t _p9 = vmulq_laneq_f32(_r.val[1], _scale1, 1); + float32x4_t _pa = vmulq_laneq_f32(_r.val[2], _scale1, 2); + float32x4_t _pb = vmulq_laneq_f32(_r.val[3], _scale1, 3); + float32x4_t _pc = vmulq_laneq_f32(_s.val[0], _scale1, 0); + float32x4_t _pd = vmulq_laneq_f32(_s.val[1], _scale1, 1); + float32x4_t _pe = vmulq_laneq_f32(_s.val[2], _scale1, 2); + float32x4_t _pf = vmulq_laneq_f32(_s.val[3], _scale1, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = 
float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale0); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale0); + _p8 = vmulq_f32(_p8, _scale1); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale1); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale1); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale1); + _pf = vmulq_f32(_pf, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale1, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale1, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale1, 2); + float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + 
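+                    // _p4.._p7 load the same four k steps from the second pack-4 row group (rows ii+4..ii+7);
+                    // both groups are multiplied by their per-row scales (127/absmax) before float2int8 saturates them to int8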
float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale1); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale1); + _p7 = vmulq_f32(_p7, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p0n = vld1q_f32(p0 + 4); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p1n = vld1q_f32(p0 + A_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p0n = vmulq_f32(_p0n, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p1n = vmulq_f32(_p1n, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p0n, _p1n); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 7 + 4); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); + _p8 = vmulq_lane_f32(_p8, 
vget_low_f32(_scale1), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); + _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); + _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); + _pc = vmulq_lane_f32(_pc, vget_high_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 7); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, 
vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + A_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + A_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + A_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + A_hstep * 7); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + float32x4x2_t _scale01 = vzipq_f32(_scale0, _scale0); + float32x4x2_t _scale23 = vzipq_f32(_scale1, _scale1); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + _p45 = vmulq_f32(_p45, _scale23.val[0]); + _p67 = vmulq_f32(_p67, _scale23.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale, 2); + 
float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, 
_scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + float32x4x2_t _scale01 = vzipq_f32(_scale, _scale); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; 
kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[A_hstep] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale = scales[ii]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, 
_p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.dims == 3 ? A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + +#if __ARM_NEON +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif +#endif + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { + int ii = 0; + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (int kk = 0; kk < K; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); + float32x2_t _aa3 = vmax_f32(vget_low_f32(_absmax3), vget_high_f32(_absmax3)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); + float32x4_t _absmax = vcombine_f32(_aa01, _aa23); + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax); + float32x4_t _out_descale = vdivq_f32(_absmax, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t 
_p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep * 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int ii = 0; +#if __ARM_NEON + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii); + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii); + + float32x2_t _absmax0 = vdup_n_f32(0.f); + float32x2_t 
_absmax1 = vdup_n_f32(0.f); + float32x2_t _absmax2 = vdup_n_f32(0.f); + float32x2_t _absmax3 = vdup_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p0)); + _absmax1 = vmax_f32(_absmax1, vabs_f32(_p1)); + _absmax2 = vmax_f32(_absmax2, vabs_f32(_p2)); + _absmax3 = vmax_f32(_absmax3, vabs_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmax_f32(_absmax0, _absmax2); + _absmax1 = vmax_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p0)); + _absmax1 = vmax_f32(_absmax1, vabs_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmax_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x2_t _p = vld1_f32(p0); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x2_t _scale = vdiv_f32(vget_low_f32(_v127), _absmax0); + float32x2_t _out_descale = vdiv_f32(_absmax0, vget_low_f32(_v127_B_scale)); + + vst1_f32(ps, _scale); + vst1_f32(pods, _out_descale); +#else + float tmp[2]; + vst1_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + + // float32x2_t _recp_absmax = vrecpe_f32(_absmax0); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x2_t _scale = vmul_f32(vget_low_f32(_v127), _recp_absmax); + // float32x2_t _out_descale = vmul_f32(_absmax0, vget_low_f32(_recp_v127_B_scale)); +#endif + + ps += 2; + pods += 2; + } +#endif // __ARM_NEON + for (; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii); + + float absmax = 0.f; + for (int kk = 0; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(p0[0])); + p0 += A_hstep; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_fp32_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_fp32_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 4 + 28); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + _p8 = vmulq_laneq_f32(_p8, _scale0, 0); + _p9 = vmulq_laneq_f32(_p9, _scale0, 1); + _pa = vmulq_laneq_f32(_pa, _scale0, 2); + _pb = vmulq_laneq_f32(_pb, _scale0, 3); + _pc = vmulq_laneq_f32(_pc, _scale1, 0); + _pd = vmulq_laneq_f32(_pd, _scale1, 1); + _pe = vmulq_laneq_f32(_pe, _scale1, 2); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale0), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale0), 1); + _pa = vmulq_lane_f32(_pa, vget_high_f32(_scale0), 0); + _pb = vmulq_lane_f32(_pb, vget_high_f32(_scale0), 1); + _pc = vmulq_lane_f32(_pc, vget_low_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_low_f32(_scale1), 1); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 0); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + 
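+                // _r4.._r7 repeat the same column pairs for the next four k values loaded from p0 + A_hstep * 4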
int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + 
A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + 
_r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 4 + 12); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + _p4 = vmulq_laneq_f32(_p4, _scale, 0); + _p5 = vmulq_laneq_f32(_p5, _scale, 1); + _p6 = vmulq_laneq_f32(_p6, _scale, 2); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; 
kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, 
_p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vzipq_f32(_scale0, _scale1).val[0]; + for (; kk + 7 < max_kk; kk += 8) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + A_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + A_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + A_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + A_hstep * 7); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = 
vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p46 = vcombine_f32(_p4, _p6); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + float32x4_t _p57 = vcombine_f32(_p5, _p7); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep + 0] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + const float scale = scales[ii]; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + 
pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + float32x4_t _p2 = float32x4_t(); + float32x4_t _p3 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + _p2 = vsetq_lane_f32(p0[A_hstep * 8], _p2, 0); + _p2 = vsetq_lane_f32(p0[A_hstep * 9], _p2, 1); + _p2 = vsetq_lane_f32(p0[A_hstep * 10], _p2, 2); + _p2 = vsetq_lane_f32(p0[A_hstep * 11], _p2, 3); + _p3 = vsetq_lane_f32(p0[A_hstep * 12], _p3, 0); + _p3 = vsetq_lane_f32(p0[A_hstep * 13], _p3, 1); + _p3 = vsetq_lane_f32(p0[A_hstep * 14], _p3, 2); + _p3 = vsetq_lane_f32(p0[A_hstep * 15], _p3, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void compute_B_fp32_int8_scale(const Mat& B, float& scale) +{ + float absmax = 0.f; +#if __ARM_NEON + float32x4_t _absmax = vdupq_n_f32(0.f); +#endif + for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++) + { + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + const float* ptr = (const float*)B + i * B_hstep * B.elempack; + + const int size = B.w * B.elempack; + + int j = 0; +#if __ARM_NEON + for (; j + 3 < size; j += 4) + { + float32x4_t _p = vld1q_f32(ptr); + _absmax = vmaxq_f32(_absmax, vabsq_f32(_p)); + ptr += 4; + } +#endif + for (; j < size; j++) + { + absmax = std::max(absmax, (float)fabsf(ptr[0])); + ptr++; + } + } +#if __ARM_NEON + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif + + scale = absmax == 0.f ? 
1.f : 127.f / absmax; +} + +static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_fp32_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_fp32_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_fp32_to_int8 %d %d %d", max_jj, max_kk, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + float32x4x4_t _r = vld4q_f32(p0 + B_hstep * 4); + float32x4x4_t _s = vld4q_f32(p0 + B_hstep * 4 + 16); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + float32x4_t _p8 = vmulq_f32(_r.val[0], _scale); + float32x4_t _p9 = vmulq_f32(_r.val[1], _scale); + float32x4_t _pa = vmulq_f32(_r.val[2], _scale); + float32x4_t _pb = vmulq_f32(_r.val[3], _scale); + float32x4_t _pc = vmulq_f32(_s.val[0], _scale); + float32x4_t _pd = vmulq_f32(_s.val[1], _scale); + float32x4_t _pe = vmulq_f32(_s.val[2], _scale); + float32x4_t _pf = vmulq_f32(_s.val[3], _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _pb = 
vld1q_f32(p0 + B_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + + _p0 = vmulq_f32(_p0, 
_scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, 
vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + B_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + B_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + B_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + B_hstep * 7); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0++; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], 
_scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 
+ B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk 
+= 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[B_hstep] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep] * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_fp32_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if 
(ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_fp32_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_fp32_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + 
int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = 
vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _p7 = 
vld1q_f32(p0 + B_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + 
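/* dotprod without i8mm: val[n] pairs row kk+n with row kk+4+n, and vst4_s8 groups four consecutive k values per column */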
int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = 
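/* non-dotprod layout: regroup the two columns into k-pairs (k0 k1, then k2 k3) per column before the int8 convert */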
float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + B_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + B_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + B_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + B_hstep * 7); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p46 = vcombine_f32(_p4, _p6); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + float32x4_t _p57 = vcombine_f32(_p5, _p7); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep + 0] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p2 = 
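/* single column with elempack=4: four consecutive pack4 rows give 16 k values */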
vld1q_f32(p0 + B_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + float32x4_t _p2 = float32x4_t(); + float32x4_t _p3 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + _p2 = vsetq_lane_f32(p0[B_hstep * 8], _p2, 0); + _p2 = vsetq_lane_f32(p0[B_hstep * 9], _p2, 1); + _p2 = vsetq_lane_f32(p0[B_hstep * 10], _p2, 2); + _p2 = vsetq_lane_f32(p0[B_hstep * 11], _p2, 3); + _p3 = vsetq_lane_f32(p0[B_hstep * 12], _p3, 0); + _p3 = vsetq_lane_f32(p0[B_hstep * 13], _p3, 1); + _p3 = vsetq_lane_f32(p0[B_hstep * 14], _p3, 2); + _p3 = vsetq_lane_f32(p0[B_hstep * 15], _p3, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + unpack_output_tile_int32_to_fp32_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = 
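/* dequantize + bias + scale, per element:
     out(ii, jj) = alpha * (descales[ii] * sum(ii, jj) + beta * c(ii, jj))
   where c follows broadcast_type_C: 0 = single scalar, 1/2 = one value per output row,
   3 = full C tile (c_elempack 1 or 4), 4 = one value per output column */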
top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = 
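/* rev64 + ext + zip permutes the diagonally stored non-dotprod sums back to the row-major order shown in the comment above */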
vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc 
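/* per-row bias: _c0 holds C for rows ii..ii+3, _c1 for rows ii+4..ii+7 */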
= vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 4 * 2); + float32x4_t _c3 = vld1q_f32(pC + 4 * 3); + float32x4_t _c4 = vld1q_f32(pC + 4 * 4); + float32x4_t _c5 = vld1q_f32(pC + 4 * 5); + float32x4_t _c6 = vld1q_f32(pC + 4 * 6); + float32x4_t _c7 = vld1q_f32(pC + 4 * 7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 4 * 2); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4 * 3); + _c4 = vld1q_f32(pC + c_hstep * 4 + 4 * 4); + _c5 = vld1q_f32(pC + c_hstep * 4 + 4 * 5); + _c6 = vld1q_f32(pC + c_hstep * 4 + 4 * 6); + _c7 = vld1q_f32(pC + c_hstep * 4 + 4 * 7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = 
vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + 16, _f4); + vst1q_f32(p0 + 20, _f5); + vst1q_f32(p0 + 24, _f6); + vst1q_f32(p0 + 28, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _f9); + vst1q_f32(p0 + out_hstep * 4 + 8, _fa); + vst1q_f32(p0 + out_hstep * 4 + 12, _fb); + vst1q_f32(p0 + out_hstep * 4 + 16, _fc); + vst1q_f32(p0 + out_hstep * 4 + 20, _fd); + vst1q_f32(p0 + out_hstep * 4 + 24, _fe); + vst1q_f32(p0 + out_hstep * 4 + 28, _ff); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + transpose4x4_ps(_f8, _f9, _fa, _fb); + transpose4x4_ps(_fc, _fd, _fe, _ff); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _fc); + vst1q_f32(p0 + out_hstep * 5, _f9); + vst1q_f32(p0 + out_hstep * 5 + 4, _fd); + vst1q_f32(p0 + out_hstep * 6, _fa); + vst1q_f32(p0 + out_hstep * 6 + 4, _fe); + vst1q_f32(p0 + out_hstep * 7, _fb); + vst1q_f32(p0 + out_hstep * 7 
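/* out_elempack == 1: after the 4x4 transposes each output row of the 8x8 tile is written at its own out_hstep offset */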
+ 4, _ff); + p0 += 8; + } + + pp += 64; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = 
vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; + } + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 4 + 4, _f5); + vst1q_f32(p0 + out_hstep * 4 + 8, _f6); + vst1q_f32(p0 + out_hstep * 4 + 12, _f7); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + p0 += 4; + } + + pp += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = 
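/* two-column case: rev64 + zip undoes the a0 b1 / a1 b0 interleave described above */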
vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep * 4); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + pC += 8; + } + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); + _c0 = _ccc0.val[0]; + _c1 = _ccc0.val[1]; + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _c45 = vcombine_f32(_cc4, _cc5); + float32x4_t _c67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); + _c2 = _ccc1.val[0]; + _c3 = _ccc1.val[1]; + pC += 2; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep * 4, _f2); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + p0 += 8; + } + if (out_elempack == 1) + { + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + float32x4x2_t _f23 = vzipq_f32(_f2, _f3); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 5, 
vget_high_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f23.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f23.val[1])); + p0 += 2; + } + + pp += 16; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + pC += 4; + } + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + p0++; + } + + pp += 8; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + 
// a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + _c4 = vld1q_f32(pC + 16); + _c5 = vld1q_f32(pC + 20); + _c6 = vld1q_f32(pC + 24); + _c7 = vld1q_f32(pC + 28); + pC += 32; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep); + _c3 = vld1q_f32(pC + c_hstep + 4); + _c4 = vld1q_f32(pC + c_hstep * 2); + _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + _c6 = vld1q_f32(pC + c_hstep * 3); + _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = 
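/* beta != 1: scale the bias as it is accumulated */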
vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + 16, _f4); + vst1q_f32(p0 + 20, _f5); + vst1q_f32(p0 + 24, _f6); + vst1q_f32(p0 + 28, _f7); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if 
(pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 1); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ps(_f0, _f1, _f2, _f3); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; + } + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + 
c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_c01, _c23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + p0 += 8; + } + if (out_elempack == 1) + { + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + p0 += 2; + } + + pp += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + pC += 4; + } + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; + } + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0++; + } + + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 
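/* 2-row tile: sums are row ii j0-3, row ii j4-7, row ii+1 j0-3, row ii+1 j4-7 */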
12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + 
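/* _f0 lanes are (row ii, j), (row ii, j+1), (row ii+1, j), (row ii+1, j+1); the matching bias is c0,c0,c1,c1 */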
{ + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c0 = vcombine_f32(_c, _c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + f1 *= alpha; + + p0[0] = f0; + p0[out_hstep] = f1; + + pp += 2; + p0++; + } + } + for (; ii < max_ii; ii += 1) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + + pp += 16; + p0 += 16; + } + for (; jj + 7 
< max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1q_f32(p0, _f0); + + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + vst1_f32(p0, _f0); + + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = f0; + + pp += 1; + p0++; + } + } +} + +static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_unpack_output_tile_int32_to_fp32_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? 
(int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + // NCNN_LOGE("transpose_unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = 
vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = 
vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + _c4 = vld1q_f32(pC + c_hstep * 4 + 16); + _c5 = vld1q_f32(pC + c_hstep * 4 + 20); + _c6 = vld1q_f32(pC + c_hstep * 4 + 24); + _c7 = vld1q_f32(pC + c_hstep * 4 + 28); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, 
_c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _ffa; + float32x4x4_t _ffb; + float32x4x4_t _ffc; + float32x4x4_t _ffd; + _ffa.val[0] = _f0; + _ffa.val[1] = _f1; + _ffa.val[2] = _f2; + _ffa.val[3] = _f3; + _ffb.val[0] = _f4; + _ffb.val[1] = _f5; + _ffb.val[2] = _f6; + _ffb.val[3] = _f7; + _ffc.val[0] = _f8; + _ffc.val[1] = _f9; + _ffc.val[2] = _fa; + _ffc.val[3] = _fb; + _ffd.val[0] = _fc; + _ffd.val[1] = _fd; + _ffd.val[2] = _fe; + _ffd.val[3] = _ff; + vst4q_f32(p0, _ffa); + vst4q_f32(p0 + 16, _ffc); + vst4q_f32(p0 + out_hstep * 4, _ffb); + vst4q_f32(p0 + out_hstep * 4 + 16, _ffd); + } + if (out_elempack == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f8); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f9); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _fa); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _fb); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 4 + 4, _fc); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 5 + 4, _fd); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 6 + 4, _fe); + vst1q_f32(p0 + out_hstep * 7, _f7); + vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + } + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t 
_sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 
4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + 16, _fb); + } + if (out_elempack == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + } + + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 
g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); + float32x4_t _cc67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); + float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _ccc0.val[0]); + _f1 = vaddq_f32(_f1, _ccc0.val[1]); + _f2 = vaddq_f32(_f2, _ccc1.val[0]); + _f3 = vaddq_f32(_f3, _ccc1.val[1]); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _ccc0.val[0], _beta); + _f1 = vmlaq_f32(_f1, _ccc0.val[1], _beta); + _f2 = vmlaq_f32(_f2, _ccc1.val[0], _beta); + _f3 = vmlaq_f32(_f3, _ccc1.val[1], _beta); + } + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + float32x2_t _cc = vld1_f32(pC); + _cc = vmul_n_f32(_cc, beta); + _c0 = vdupq_lane_f32(_cc, 0); + _c1 = vdupq_lane_f32(_cc, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + 
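+                // store two transposed output rows (jj and jj+1), eight dequantized values each; p0 advances by out_hstep per output row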
vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + pp += 8; + p0 += out_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, 
_sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + _c4 = vld1q_f32(pC + 16); + _c5 = vld1q_f32(pC + 20); + _c6 = vld1q_f32(pC + 24); + _c7 = vld1q_f32(pC + 28); + pC += 32; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep); + _c3 = vld1q_f32(pC + c_hstep + 4); + _c4 = vld1q_f32(pC + c_hstep * 2); + _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + _c6 = vld1q_f32(pC + c_hstep * 3); + _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = 
vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + out_hstep * 4, _fb); + } + if (out_elempack == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + } + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = 
vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _f; + _f.val[0] = _f0; + _f.val[1] = _f1; + _f.val[2] = _f2; + _f.val[3] = _f3; + vst4q_f32(p0, _f); + } + if (out_elempack == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + } + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + 
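+                    // broadcast_type_C == 4: C is a row vector, so each column's bias (scaled by beta) is splatted across the four output rows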
if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + pC += 4; + } + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1q_f32(p0, _f0); + pp += 4; + p0 += out_hstep; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale01 = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, 
_beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + } + if (out_elempack == 1) + { + float32x4x2_t _f02 = vzipq_f32(_f0, _f2); + float32x4x2_t _f13 = vzipq_f32(_f1, _f3); + vst1_f32(p0, vget_low_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f13.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f13.val[1])); + } + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + if (out_elempack == 1) + { + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + } + + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + 
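+                    // broadcast_type_C == 3: per-element C with c_elempack == 1; zip the two C rows so the bias lanes match the deinterleaved sums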
if (broadcast_type_C == 3) + { + // c_elempack == 1 + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _cc = vld1_f32(pC); + float32x2x2_t _c01 = vzip_f32(_cc, _cc); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + f1 *= alpha; + + p0[0] = f0; + p0[1] = f1; + + pp += 2; + p0 += out_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + 
vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + } + else + { + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); + } + if (out_elempack == 1) + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); + p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); + p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); + p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); + p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); + p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); + p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); + p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + } + } + + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + } + if (out_elempack == 1) + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + } + } + + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + } + else + { + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + } + if (out_elempack == 1) + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + } + } + + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 
1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + if (out_hstep == 1) + { + vst1_f32(p0, _f0); + } + else + { + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + } + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = f0; + + pp += 1; + p0 += out_hstep; + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + gemm_transB_packed_tile_int8_i8mm(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + gemm_transB_packed_tile_int8_asimddp(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + + const signed char* pAT = AT_tile; + const signed char* pBT = BT_tile; + + int* outptr = topT_tile; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* pB = pBT; + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( +#if !__ARM_FEATURE_MATMUL_INT8 + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" +#endif // !__ARM_FEATURE_MATMUL_INT8 + +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v0.16b, v0.16b, v0.16b \n" + "eor v1.16b, v1.16b, v1.16b \n" + "eor v2.16b, v2.16b, v2.16b \n" + "eor v3.16b, v3.16b, v3.16b \n" + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b 
\n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + + "2: \n" + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%1], #64 \n" + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%2], #64 \n" + "smmla v0.4s, v16.16b, v20.16b \n" + "smmla v1.4s, v17.16b, v20.16b \n" + "smmla v2.4s, v16.16b, v21.16b \n" + "smmla v3.4s, v17.16b, v21.16b \n" + "smmla v4.4s, v18.16b, v20.16b \n" + "smmla v5.4s, v19.16b, v20.16b \n" + "smmla v6.4s, v18.16b, v21.16b \n" + "smmla v7.4s, v19.16b, v21.16b \n" + "subs w4, w4, #1 \n" + "smmla v8.4s, v16.16b, v22.16b \n" + "smmla v9.4s, v17.16b, v22.16b \n" + "smmla v10.4s, v16.16b, v23.16b \n" + "smmla v11.4s, v17.16b, v23.16b \n" + "smmla v12.4s, v18.16b, v22.16b \n" + "smmla v13.4s, v19.16b, v22.16b \n" + "smmla v14.4s, v18.16b, v23.16b \n" + "smmla v15.4s, v19.16b, v23.16b \n" + "bne 2b \n" + + "uzp1 v16.4s, v0.4s, v1.4s \n" + "uzp2 v17.4s, v0.4s, v1.4s \n" + "uzp1 v18.4s, v2.4s, v3.4s \n" + "uzp2 v19.4s, v2.4s, v3.4s \n" + "uzp1 v20.4s, v4.4s, v5.4s \n" + "uzp2 v21.4s, v4.4s, v5.4s \n" + "uzp1 v22.4s, v6.4s, v7.4s \n" + "uzp2 v23.4s, v6.4s, v7.4s \n" + "uzp1 v24.4s, v8.4s, v9.4s \n" + "uzp2 v25.4s, v8.4s, v9.4s \n" + "uzp1 v26.4s, v10.4s, v11.4s \n" + "uzp2 v27.4s, v10.4s, v11.4s \n" + "uzp1 v28.4s, v12.4s, v13.4s \n" + "uzp2 v29.4s, v12.4s, v13.4s \n" + "uzp1 v30.4s, v14.4s, v15.4s \n" + "uzp2 v31.4s, v14.4s, v15.4s \n" + + "cmp %w7, #0 \n" + "beq 1f \n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" + "add v24.4s, v24.4s, v8.4s \n" + "add v25.4s, v25.4s, v9.4s \n" + "add v26.4s, v26.4s, v10.4s \n" + "add v27.4s, v27.4s, v11.4s \n" + "add v28.4s, v28.4s, v12.4s \n" + "add v29.4s, v29.4s, v13.4s \n" + "add v30.4s, v30.4s, v14.4s \n" + "add v31.4s, v31.4s, v15.4s \n" + "b 1f \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "2: \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" + "ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [%2], #64 \n" + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v4.4b[0] \n" + "sdot v21.4s, v1.16b, v4.4b[1] \n" + "sdot v22.4s, v1.16b, v4.4b[2] \n" + "sdot v23.4s, v1.16b, v4.4b[3] \n" + "sdot v24.4s, v0.16b, v5.4b[0] \n" + "sdot v25.4s, v0.16b, v5.4b[1] \n" + "sdot v26.4s, v0.16b, v5.4b[2] \n" + "sdot v27.4s, v0.16b, v5.4b[3] \n" + "sdot v28.4s, v1.16b, v5.4b[0] \n" + "sdot v29.4s, v1.16b, v5.4b[1] \n" + "sdot v30.4s, v1.16b, v5.4b[2] \n" + "sdot v31.4s, v1.16b, v5.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v2.16b, v6.4b[0] \n" + "sdot v17.4s, v2.16b, v6.4b[1] \n" + "sdot v18.4s, v2.16b, v6.4b[2] \n" + "sdot v19.4s, v2.16b, v6.4b[3] \n" + "sdot v20.4s, v3.16b, v6.4b[0] \n" + "sdot v21.4s, v3.16b, v6.4b[1] \n" + "sdot v22.4s, v3.16b, v6.4b[2] \n" + "sdot v23.4s, v3.16b, v6.4b[3] \n" + "sdot v24.4s, v2.16b, v7.4b[0] \n" + "sdot v25.4s, v2.16b, v7.4b[1] \n" + "sdot v26.4s, v2.16b, v7.4b[2] \n" + "sdot v27.4s, v2.16b, v7.4b[3] \n" + "sdot v28.4s, v3.16b, v7.4b[0] \n" + "sdot v29.4s, v3.16b, v7.4b[1] \n" + "sdot v30.4s, v3.16b, 
v7.4b[2] \n" + "sdot v31.4s, v3.16b, v7.4b[3] \n" + "bne 2b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" +#if __ARM_FEATURE_MATMUL_INT8 + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + "1: \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b, v3.16b}, [%2], #32 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v1.16b, v2.4b[0] \n" + "sdot v21.4s, v1.16b, v2.4b[1] \n" + "sdot v22.4s, v1.16b, v2.4b[2] \n" + "sdot v23.4s, v1.16b, v2.4b[3] \n" + "sdot v24.4s, v0.16b, v3.4b[0] \n" + "sdot v25.4s, v0.16b, v3.4b[1] \n" + "sdot v26.4s, v0.16b, v3.4b[2] \n" + "sdot v27.4s, v0.16b, v3.4b[3] \n" + "sdot v28.4s, v1.16b, v3.4b[0] \n" + "sdot v29.4s, v1.16b, v3.4b[1] \n" + "sdot v30.4s, v1.16b, v3.4b[2] \n" + "sdot v31.4s, v1.16b, v3.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v4.16b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v4.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull2 v13.8h, v0.16b, v6.16b \n" + "rev64 v3.4s, v1.4s \n" + "smull v14.8h, v2.8b, v6.8b \n" + "smull2 v15.8h, v2.16b, v6.16b \n" + "rev64 v7.8h, v5.8h \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v5.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v5.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v7.16b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v15.8h, v3.16b, v7.16b \n" + "ext v0.16b, v0.16b, v0.16b, #8 \n" + "ext v2.16b, v2.16b, v2.16b, #8 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v20.4s, v10.8h \n" + "sadalp v21.4s, v11.8h \n" + "ext v1.16b, v1.16b, v1.16b, #8 \n" + "ext v3.16b, v3.16b, v3.16b, #8 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v4.16b \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v4.16b \n" + "sadalp v24.4s, v12.8h \n" + "sadalp v25.4s, v13.8h \n" + "sadalp v28.4s, v14.8h \n" + "sadalp v29.4s, v15.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull2 v13.8h, v0.16b, v6.16b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "smull2 v15.8h, v2.16b, v6.16b \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v5.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v5.16b \n" + "smlal v12.8h, v1.8b, 
v7.8b \n" + "smlal2 v13.8h, v1.16b, v7.16b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v15.8h, v3.16b, v7.16b \n" + "subs w4, w4, #1 \n" + "sadalp v18.4s, v8.8h \n" + "sadalp v19.4s, v9.8h \n" + "sadalp v22.4s, v10.8h \n" + "sadalp v23.4s, v11.8h \n" + "sadalp v26.4s, v12.8h \n" + "sadalp v27.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v1.16b}, [%2], #16 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" + "dup v4.8h, v1.h[4] \n" + "dup v5.8h, v1.h[5] \n" + "dup v6.8h, v1.h[6] \n" + "dup v7.8h, v1.h[7] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v24.4s, v8.8h \n" + "sadalp v25.4s, v9.8h \n" + "sadalp v26.4s, v10.8h \n" + "sadalp v27.4s, v11.8h \n" + "sadalp v28.4s, v12.8h \n" + "sadalp v29.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v20.4s, v10.8h \n" + "sadalp v21.4s, v11.8h \n" + "sadalp v24.4s, v12.8h \n" + "sadalp v25.4s, v13.8h \n" + "sadalp v28.4s, v14.8h \n" + "sadalp v29.4s, v15.8h \n" + "ext v0.16b, v0.16b, v0.16b, #8 \n" + "ext v1.16b, v1.16b, v1.16b, #8 \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v18.4s, v8.8h \n" + "sadalp v19.4s, v9.8h \n" + "sadalp v22.4s, v10.8h \n" + "sadalp v23.4s, v11.8h \n" + "sadalp v26.4s, v12.8h \n" + "sadalp v27.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v8.8b, v1.b[0] \n" + "dup v9.8b, v1.b[1] \n" + "dup v10.8b, v1.b[2] \n" + "dup v11.8b, v1.b[3] \n" + "dup v12.8b, v1.b[4] \n" + "dup v13.8b, v1.b[5] \n" + "dup v14.8b, v1.b[6] \n" + "dup v15.8b, v1.b[7] \n" + 
"smull v8.8h, v0.8b, v8.8b \n" + "smull v9.8h, v0.8b, v9.8b \n" + "smull v10.8h, v0.8b, v10.8b \n" + "smull v11.8h, v0.8b, v11.8b \n" + "smull v12.8h, v0.8b, v12.8b \n" + "smull v13.8h, v0.8b, v13.8b \n" + "smull v14.8h, v0.8b, v14.8b \n" + "smull v15.8h, v0.8b, v15.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw v18.4s, v18.4s, v10.4h \n" + "saddw v19.4s, v19.4s, v11.4h \n" + "saddw2 v20.4s, v20.4s, v8.8h \n" + "saddw2 v21.4s, v21.4s, v9.8h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" + "saddw v24.4s, v24.4s, v12.4h \n" + "saddw v25.4s, v25.4s, v13.4h \n" + "saddw v26.4s, v26.4s, v14.4h \n" + "saddw v27.4s, v27.4s, v15.4h \n" + "saddw2 v28.4s, v28.4s, v12.8h \n" + "saddw2 v29.4s, v29.4s, v13.8h \n" + "saddw2 v30.4s, v30.4s, v14.8h \n" + "saddw2 v31.4s, v31.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v4.8b}, [%2], #8 \n" + "ext v1.8b, v0.8b, v0.8b, #4 \n" + "rev32 v2.4h, v0.4h \n" + "rev64 v3.4h, v0.4h \n" + "rev32 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v1.8b, v4.8b \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull v11.8h, v3.8b, v4.8b \n" + "smull v12.8h, v0.8b, v5.8b \n" + "smull v13.8h, v1.8b, v5.8b \n" + "smull v14.8h, v2.8b, v5.8b \n" + "smull v15.8h, v3.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" + "saddw v24.4s, v24.4s, v12.4h \n" + "saddw2 v25.4s, v25.4s, v12.8h \n" + "saddw v26.4s, v26.4s, v13.4h \n" + "saddw2 v27.4s, v27.4s, v13.8h \n" + "saddw v28.4s, v28.4s, v14.4h \n" + "saddw2 v29.4s, v29.4s, v14.8h \n" + "saddw v30.4s, v30.4s, v15.4h \n" + "saddw2 v31.4s, v31.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + +#if __ARM_FEATURE_MATMUL_INT8 + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } +#else // __ARM_FEATURE_MATMUL_INT8 + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + 
_sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + _sum0 = vmmlaq_s32(_sum0, _pA0, _pB0); + _sum1 = vmmlaq_s32(_sum1, _pA1, _pB0); + _sum2 = vmmlaq_s32(_sum2, _pA0, _pB1); + _sum3 = vmmlaq_s32(_sum3, _pA1, _pB1); + _sum4 = vmmlaq_s32(_sum4, _pA2, _pB0); + _sum5 = vmmlaq_s32(_sum5, _pA3, _pB0); + _sum6 = vmmlaq_s32(_sum6, _pA2, _pB1); + _sum7 = vmmlaq_s32(_sum7, _pA3, _pB1); + _sum8 = vmmlaq_s32(_sum8, _pA0, _pB2); + _sum9 = vmmlaq_s32(_sum9, _pA1, _pB2); + _suma = vmmlaq_s32(_suma, _pA0, _pB3); + _sumb = vmmlaq_s32(_sumb, _pA1, _pB3); + _sumc = vmmlaq_s32(_sumc, _pA2, _pB2); + _sumd = vmmlaq_s32(_sumd, _pA3, _pB2); + _sume = vmmlaq_s32(_sume, _pA2, _pB3); + _sumf = vmmlaq_s32(_sumf, _pA3, _pB3); + + pA += 64; + pB += 64; + } + + int32x4x2_t _ss0 = vuzpq_s32(_sum0, _sum1); + int32x4x2_t _ss1 = vuzpq_s32(_sum2, _sum3); + int32x4x2_t _ss2 = vuzpq_s32(_sum4, _sum5); + int32x4x2_t _ss3 = vuzpq_s32(_sum6, _sum7); + int32x4x2_t _ss4 = vuzpq_s32(_sum8, _sum9); + int32x4x2_t _ss5 = vuzpq_s32(_suma, _sumb); + int32x4x2_t _ss6 = vuzpq_s32(_sumc, _sumd); + int32x4x2_t _ss7 = vuzpq_s32(_sume, _sumf); + + if (k == 0) + { + _sum0 = _ss0.val[0]; + _sum1 = _ss0.val[1]; + _sum2 = _ss1.val[0]; + _sum3 = _ss1.val[1]; + _sum4 = _ss2.val[0]; + _sum5 = _ss2.val[1]; + _sum6 = _ss3.val[0]; + _sum7 = _ss3.val[1]; + _sum8 = _ss4.val[0]; + _sum9 = _ss4.val[1]; + _suma = _ss5.val[0]; + _sumb = _ss5.val[1]; + _sumc = _ss6.val[0]; + _sumd = _ss6.val[1]; + _sume = _ss7.val[0]; + _sumf = _ss7.val[1]; + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, 
_ss3.val[1]); + _sum8 = vaddq_s32(_sum8, _ss4.val[0]); + _sum9 = vaddq_s32(_sum9, _ss4.val[1]); + _suma = vaddq_s32(_suma, _ss5.val[0]); + _sumb = vaddq_s32(_sumb, _ss5.val[1]); + _sumc = vaddq_s32(_sumc, _ss6.val[0]); + _sumd = vaddq_s32(_sumd, _ss6.val[1]); + _sume = vaddq_s32(_sume, _ss7.val[0]); + _sumf = vaddq_s32(_sumf, _ss7.val[1]); + } + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 4444 5555 6666 7777 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA0, _pB1, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA0, _pB1, 1); + _suma = vdotq_laneq_s32(_suma, _pA0, _pB1, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA0, _pB1, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA1, _pB1, 0); + _sumd = vdotq_laneq_s32(_sumd, _pA1, _pB1, 1); + _sume = vdotq_laneq_s32(_sume, _pA1, _pB1, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA1, _pB1, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB2, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB2, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA2, _pB2, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA2, _pB2, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA3, _pB2, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA3, _pB2, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA3, _pB2, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA3, _pB2, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA2, _pB3, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA2, _pB3, 1); + _suma = vdotq_laneq_s32(_suma, _pA2, _pB3, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA2, _pB3, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA3, _pB3, 0); + _sumd = vdotq_laneq_s32(_sumd, _pA3, _pB3, 1); + _sume = vdotq_laneq_s32(_sume, _pA3, _pB3, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA3, _pB3, 3); + + pA += 64; + pB += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 4444 5555 6666 7777 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA0, _pB1, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA0, _pB1, 1); + _suma = vdotq_laneq_s32(_suma, _pA0, _pB1, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA0, _pB1, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA1, _pB1, 0); + _sumd = vdotq_laneq_s32(_sumd, _pA1, _pB1, 1); + _sume = vdotq_laneq_s32(_sume, _pA1, _pB1, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA1, _pB1, 3); 
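+            // note: descriptive annotation of the sdot-only path above — each
+            // vdotq_laneq_s32 accumulates a 4-long dot product of four packed A rows
+            // against one packed B column lane, so these 16 accumulators together
+            // cover the whole 8x8 output tile for this k-step of 4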
+ +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB2 = vld1q_s8(pB + 16); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA3 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA2))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB3 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _s9 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sa = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sc = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB1)); + int16x8_t _se = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sf = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB1)); + + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), vget_low_s8(_pB2)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB2)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA2), vget_low_s8(_pB2)); + _s3 = vmlal_s8(_s3, vget_low_s8(_pA2), vget_high_s8(_pB2)); + _s4 = vmlal_s8(_s4, vget_low_s8(_pA3), vget_low_s8(_pB2)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA3), vget_high_s8(_pB2)); + _s6 = vmlal_s8(_s6, vget_high_s8(_pA3), vget_low_s8(_pB2)); + _s7 = vmlal_s8(_s7, vget_low_s8(_pA3), vget_high_s8(_pB2)); + _s8 = vmlal_s8(_s8, vget_low_s8(_pA2), vget_low_s8(_pB3)); + _s9 = vmlal_s8(_s9, vget_high_s8(_pA2), vget_high_s8(_pB3)); + _sa = vmlal_s8(_sa, vget_high_s8(_pA2), vget_low_s8(_pB3)); + _sb = vmlal_s8(_sb, vget_low_s8(_pA2), vget_high_s8(_pB3)); + _sc = vmlal_s8(_sc, vget_low_s8(_pA3), vget_low_s8(_pB3)); + _sd = vmlal_s8(_sd, vget_high_s8(_pA3), vget_high_s8(_pB3)); + _se = vmlal_s8(_se, vget_high_s8(_pA3), vget_low_s8(_pB3)); + _sf = vmlal_s8(_sf, vget_low_s8(_pA3), vget_high_s8(_pB3)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + // aabbccdd eeffgghh + 
+ // 00112233 44556677 + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 0))); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 1))); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 2))); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 3))); + int16x8_t _s4 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 0))); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 1))); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 2))); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 3))); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 0))); + int16x8_t _s9 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 1))); + int16x8_t _sa = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 2))); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 3))); + int16x8_t _sc = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 0))); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 1))); + int16x8_t _se = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 2))); + int16x8_t _sf = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 3))); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 44556677 + + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA0), 
vget_low_s8(_pB1)); + int16x8_t _s9 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sa = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sc = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB1)); + int16x8_t _se = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sf = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB1)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + // int8x8_t _pB0 = vld1_s8(pB); + + // abcd efgh + // 0123 4567 + + int16x8_t _s01 = vmull_s8(_pA, vdup_n_s8(pB[0])); + int16x8_t _s23 = vmull_s8(_pA, vdup_n_s8(pB[1])); + int16x8_t _s45 = vmull_s8(_pA, vdup_n_s8(pB[2])); + int16x8_t _s67 = vmull_s8(_pA, vdup_n_s8(pB[3])); + int16x8_t _s89 = vmull_s8(_pA, vdup_n_s8(pB[4])); + int16x8_t _sab = vmull_s8(_pA, vdup_n_s8(pB[5])); + int16x8_t _scd = vmull_s8(_pA, vdup_n_s8(pB[6])); + int16x8_t _sef = vmull_s8(_pA, vdup_n_s8(pB[7])); + + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s23)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s45)); + _sum3 = vaddw_s16(_sum3, vget_low_s16(_s67)); + _sum4 = vaddw_s16(_sum4, vget_high_s16(_s01)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s23)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s45)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); + _sum8 = vaddw_s16(_sum8, vget_low_s16(_s89)); + _sum9 = vaddw_s16(_sum9, vget_low_s16(_sab)); + _suma = vaddw_s16(_suma, vget_low_s16(_scd)); + _sumb = vaddw_s16(_sumb, vget_low_s16(_sef)); + _sumc = vaddw_s16(_sumc, vget_high_s16(_s89)); + _sumd = vaddw_s16(_sumd, vget_high_s16(_sab)); + _sume = vaddw_s16(_sume, vget_high_s16(_scd)); + _sumf = vaddw_s16(_sumf, vget_high_s16(_sef)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // abcd efgh + // efgh abcd + // cdab ghef + // ghef cdab + + // 0123 4567 + // 3210 7654 + + // abcdefgh -> ghefcdab -> cdabghef + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 4); + int8x8_t _pA2 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + int8x8_t _pA3 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pA0))); + + // 01234567 -> 32107654 + + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA2, _pB0); + int16x8_t _s67 = vmull_s8(_pA3, _pB0); + int16x8_t _s89 = vmull_s8(_pA0, _pB1); + int16x8_t _sab = vmull_s8(_pA1, _pB1); + int16x8_t _scd = vmull_s8(_pA2, _pB1); + int16x8_t _sef = vmull_s8(_pA3, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = 
vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); + _sum8 = vaddw_s16(_sum8, vget_low_s16(_s89)); + _sum9 = vaddw_s16(_sum9, vget_high_s16(_s89)); + _suma = vaddw_s16(_suma, vget_low_s16(_sab)); + _sumb = vaddw_s16(_sumb, vget_high_s16(_sab)); + _sumc = vaddw_s16(_sumc, vget_low_s16(_scd)); + _sumd = vaddw_s16(_sumd, vget_high_s16(_scd)); + _sume = vaddw_s16(_sume, vget_low_s16(_sef)); + _sumf = vaddw_s16(_sumf, vget_high_s16(_sef)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + + outptr += 64; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v4.16b \n" + "smmla v25.4s, v1.16b, v4.16b \n" + "smmla v26.4s, v0.16b, v5.16b \n" + "smmla v27.4s, v1.16b, v5.16b \n" + "subs w4, w4, #1 \n" + "smmla v28.4s, v2.16b, v4.16b \n" + "smmla v29.4s, v3.16b, v4.16b \n" + "smmla v30.4s, v2.16b, v5.16b \n" + "smmla v31.4s, v3.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v4.4b[0] \n" + "sdot v21.4s, v1.16b, v4.4b[1] \n" + "sdot v22.4s, v1.16b, v4.4b[2] \n" + "sdot v23.4s, v1.16b, v4.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v2.16b, v5.4b[0] \n" + "sdot v17.4s, v2.16b, v5.4b[1] \n" + "sdot v18.4s, v2.16b, v5.4b[2] \n" + "sdot v19.4s, v2.16b, v5.4b[3] \n" + "sdot v20.4s, v3.16b, v5.4b[0] \n" + "sdot v21.4s, v3.16b, v5.4b[1] \n" + "sdot v22.4s, v3.16b, v5.4b[2] \n" + "sdot v23.4s, v3.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, 
v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + "uzp1 v4.4s, v28.4s, v29.4s \n" + "uzp2 v5.4s, v28.4s, v29.4s \n" + "uzp1 v6.4s, v30.4s, v31.4s \n" + "uzp2 v7.4s, v30.4s, v31.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v1.16b, v2.4b[0] \n" + "sdot v21.4s, v1.16b, v2.4b[1] \n" + "sdot v22.4s, v1.16b, v2.4b[2] \n" + "sdot v23.4s, v1.16b, v2.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b}, [%2], #16 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "ext v5.16b, v4.16b, v4.16b, #8 \n" + "smull2 v9.8h, v0.16b, v5.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull2 v11.8h, v2.16b, v5.16b \n" + "ext v7.16b, v6.16b, v6.16b, #8 \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "rev64 v3.4s, v1.4s \n" + "smull2 v13.8h, v0.16b, v7.16b \n" + "smull2 v15.8h, v2.16b, v7.16b \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal2 v11.8h, v3.16b, v4.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v6.16b \n" + "smlal2 v15.8h, v3.16b, v6.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v23.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1r {v2.2d}, [%2] \n" + "add %2, %2, #8 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + 
"sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2] \n" + "add %2, %2, #4 \n" + "dup v8.8b, v1.b[0] \n" + "dup v9.8b, v1.b[1] \n" + "dup v10.8b, v1.b[2] \n" + "dup v11.8b, v1.b[3] \n" + "smull v8.8h, v0.8b, v8.8b \n" + "smull v9.8h, v0.8b, v9.8b \n" + "smull v10.8h, v0.8b, v10.8b \n" + "smull v11.8h, v0.8b, v11.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw v18.4s, v18.4s, v10.4h \n" + "saddw v19.4s, v19.4s, v11.4h \n" + "saddw2 v20.4s, v20.4s, v8.8h \n" + "saddw2 v21.4s, v21.4s, v9.8h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1r {v4.2s}, [%2] \n" + "add %2, %2, #4 \n" + "rev32 v1.4h, v0.4h \n" + "rev64 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v1.8b, v4.8b \n" + "smull v10.8h, v0.8b, v5.8b \n" + "smull v11.8h, v1.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0!, {d16-d23} \n" + "vldm %0, {d24-d31} \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + ".align 4 \n" + "2: \n" + "pld [%1, #256] \n" + "vld1.s8 {d0-d3}, [%1 :64]! \n" + "pld [%2, #128] \n" + "vld1.s8 {d4-d5}, [%2]! \n" + "vmull.s8 q4, d0, d4 \n" + "vrev64.32 q3, q0 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d6, d4 \n" + "vmull.s8 q7, d7, d4 \n" + "vrev64.32 q0, q1 \n" + "vmlal.s8 q4, d2, d5 \n" + "vmlal.s8 q5, d3, d5 \n" + "vmlal.s8 q6, d0, d5 \n" + "vmlal.s8 q7, d1, d5 \n" + "vrev64.16 q2, q2 \n" + "vpadal.s16 q8, q4 \n" + "vrev64.32 q1, q3 \n" + "vpadal.s16 q9, q5 \n" + "vmull.s8 q4, d6, d4 \n" + "vpadal.s16 q10, q6 \n" + "vmull.s8 q5, d7, d4 \n" + "vpadal.s16 q11, q7 \n" + "vmull.s8 q6, d2, d4 \n" + "vmull.s8 q7, d3, d4 \n" + "vrev64.32 q3, q0 \n" + "vmlal.s8 q4, d0, d5 \n" + "vmlal.s8 q5, d1, d5 \n" + "vmlal.s8 q6, d6, d5 \n" + "vmlal.s8 q7, d7, d5 \n" + "subs r4, r4, #1 \n" + "vpadal.s16 q14, q4 \n" + "vpadal.s16 q15, q5 \n" + "vpadal.s16 q12, q6 \n" + "vpadal.s16 q13, q7 \n" + "bne 2b \n" + + "3: \n" + "and r4, %6, #2 \n" // r4 = remain = max_kk & 2 + "cmp r4, #0 \n" + "beq 4f \n" + + // kk += 2 part + "vld1.s8 {d0-d1}, [%1 :64]! \n" + "vld1.s8 {d4}, [%2]! 
\n" + "vrev64.32 q1, q0 \n" + "vrev64.16 d5, d4 \n" + "vmull.s8 q4, d0, d4 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d2, d4 \n" + "vmull.s8 q7, d3, d4 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + "vmull.s8 q4, d0, d5 \n" + "vmull.s8 q5, d1, d5 \n" + "vmull.s8 q6, d2, d5 \n" + "vmull.s8 q7, d3, d5 \n" + "vpadal.s16 q12, q4 \n" + "vpadal.s16 q13, q5 \n" + "vpadal.s16 q14, q6 \n" + "vpadal.s16 q15, q7 \n" + + "4: \n" + "and r4, %6, #1 \n" // r4 = remain = max_kk & 1 + "cmp r4, #0 \n" + "beq 5f \n" + + // kk += 1 part + "vld1.s8 {d0}, [%1 :64]! \n" + "vld1.s32 {d2[]}, [%2]! \n" + "vrev64.16 d1, d0 \n" + "vrev64.8 d3, d2 \n" + "vext.s8 d1, d1, #4 \n" + "vmull.s8 q4, d0, d2 \n" + "vmull.s8 q5, d1, d2 \n" + "vmull.s8 q6, d0, d3 \n" + "vmull.s8 q7, d1, d3 \n" + "vaddw.s16 q8, d8 \n" + "vaddw.s16 q9, d9 \n" + "vaddw.s16 q10, d10 \n" + "vaddw.s16 q11, d11 \n" + "vaddw.s16 q12, d12 \n" + "vaddw.s16 q13, d13 \n" + "vaddw.s16 q14, d14 \n" + "vaddw.s16 q15, d15 \n" + + "5: \n" + "vstm %0!, {d16-d23} \n" + "vstm %0!, {d24-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); + int32x4_t _s4 = vdupq_n_s32(0); + int32x4_t _s5 = vdupq_n_s32(0); + int32x4_t _s6 = vdupq_n_s32(0); + int32x4_t _s7 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... 
hhhhhhhh + // 00000000 11111111 22222222 33333333 + + _s0 = vmmlaq_s32(_s0, _pA0, _pB0); + _s1 = vmmlaq_s32(_s1, _pA1, _pB0); + _s2 = vmmlaq_s32(_s2, _pA0, _pB1); + _s3 = vmmlaq_s32(_s3, _pA1, _pB1); + _s4 = vmmlaq_s32(_s4, _pA2, _pB0); + _s5 = vmmlaq_s32(_s5, _pA3, _pB0); + _s6 = vmmlaq_s32(_s6, _pA2, _pB1); + _s7 = vmmlaq_s32(_s7, _pA3, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB1, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB1, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA2, _pB1, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA2, _pB1, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA3, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA3, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA3, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA3, _pB1, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss0 = vuzpq_s32(_s0, _s1); + int32x4x2_t _ss1 = vuzpq_s32(_s2, _s3); + int32x4x2_t _ss2 = vuzpq_s32(_s4, _s5); + int32x4x2_t _ss3 = vuzpq_s32(_s6, _s7); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, _ss3.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB = vld1q_s8(pB); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x16_t _pB02 = vld1q_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + int8x16_t _pA3 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA2))); + + // 00112233 44556677 + + // 33221100 77665544 + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB02)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB02)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB13)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB13)); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB13)); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB13)); + + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), 
vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_low_s8(_pA3), vget_high_s8(_pB02)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA3), vget_high_s8(_pB02)); + _s4 = vmlal_s8(_s4, vget_low_s8(_pA2), vget_high_s8(_pB13)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA2), vget_high_s8(_pB13)); + _s6 = vmlal_s8(_s6, vget_low_s8(_pA3), vget_high_s8(_pB13)); + _s7 = vmlal_s8(_s7, vget_high_s8(_pA3), vget_high_s8(_pB13)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh + + // 00112233 + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + int16x8_t _s4 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 + + // 33221100 + + int8x8_t _pB1 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), _pB0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA1), _pB0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA1), _pB0); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA0), _pB1); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA0), _pB1); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA1), _pB1); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA1), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + // int8x8_t _pB0 = vreinterpret_s32_s8(vld1_dup_s32(pB)); + + // abcdefgh + + // 0123 + + int16x8_t _s01 = vmull_s8(_pA0, vdup_n_s8(pB[0])); + int16x8_t _s23 = vmull_s8(_pA0, 
vdup_n_s8(pB[1])); + int16x8_t _s45 = vmull_s8(_pA0, vdup_n_s8(pB[2])); + int16x8_t _s67 = vmull_s8(_pA0, vdup_n_s8(pB[3])); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s23)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s45)); + _sum3 = vaddw_s16(_sum3, vget_low_s16(_s67)); + _sum4 = vaddw_s16(_sum4, vget_high_s16(_s01)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s23)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s45)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + // int8x8_t _pB0 = vld1_s8(pB); + // _pB0 = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB0), vreinterpret_s32_s8(_pB0)).val[0]); + + // abcdefgh -> cdabghef + int8x8_t _pA1 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + + // 01230123 -> 32103210 + int8x8_t _pB1 = vrev64_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA0, _pB1); + int16x8_t _s67 = vmull_s8(_pA1, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... 
hhhhhhhh + // 00000000 11111111 + + _s0 = vmmlaq_s32(_s0, _pA0, _pB); + _s1 = vmmlaq_s32(_s1, _pA1, _pB); + _s2 = vmmlaq_s32(_s2, _pA2, _pB); + _s3 = vmmlaq_s32(_s3, _pA3, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB, 0); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB, 1); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB, 2); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB, 3); + _sum2 = vdotq_laneq_s32(_sum2, _pA3, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA3, _pB, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 16; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss0 = vuzpq_s32(_s0, _s1); + int32x4x2_t _ss1 = vuzpq_s32(_s2, _s3); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 + + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_lane_s32(_sum2, _pA1, _pB, 0); + _sum3 = vdotq_lane_s32(_sum3, _pA1, _pB, 1); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh aabbccdd eeffgghh + + // 00112233 -> 00110011 22332233 + + // 11001100 33223322 + + int32x2x2_t _pBB = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + int8x16_t _pB02 = vreinterpretq_s8_s32(vcombine_s32(_pBB.val[0], _pBB.val[1])); + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB13)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB13)); + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_low_s8(_pA2), vget_high_s8(_pB13)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA2), vget_high_s8(_pB13)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int16x4_t _pB = vreinterpret_s16_s32(vld1_dup_s32((const int*)pB)); + + int16x4x2_t _pB01 = vuzp_s16(_pB, _pB); + int8x8_t _pB0 = vreinterpret_s8_s16(_pB01.val[0]); + int8x8_t _pB1 = vreinterpret_s8_s16(_pB01.val[1]); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), _pB1); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA), _pB0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aabbccdd eeffgghh + + // 00110011 + // 11001100 + + int8x8_t _pB1 = 
vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA), _pB0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), _pB1); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int8x8x2_t _pB01 = vuzp_s8(_pB, _pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB01.val[0]); + int16x8_t _s1 = vmull_s8(_pA, _pB01.val[1]); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s1)); + _sum2 = vaddw_s16(_sum2, vget_high_s16(_s0)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcdefgh + + // 01010101 + // 10101010 + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 1); + + int16x8_t _s0 = vmull_s8(_pA, _pB0); + int16x8_t _s1 = vmull_s8(_pA, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s1)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... 
hhhhhhhh + // 00000000 + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _s0 = vdotq_s32(_s0, _pA0, _pBB); + _s1 = vdotq_s32(_s1, _pA1, _pBB); + _s2 = vdotq_s32(_s2, _pA2, _pBB); + _s3 = vdotq_s32(_s3, _pA3, _pBB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA1, _pB, 0); + _sum0 = vdotq_lane_s32(_sum0, _pA2, _pB, 1); + _sum1 = vdotq_lane_s32(_sum1, _pA3, _pB, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 8; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_s0, _s1)); + _sum1 = vaddq_s32(_sum1, vpaddq_s32(_s2, _s3)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 0000 + + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA1, _pB, 0); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + int8x8_t _pB1 = vreinterpret_s8_s16(vld1_dup_s16((const short*)(pB + 2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), _pB0); + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), _pB1); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA), _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_dup_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 8; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* pB = pBT; + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" +#endif // 
__ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [%2], #64 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v2.16b \n" + "smmla v25.4s, v1.16b, v2.16b \n" + "smmla v26.4s, v0.16b, v3.16b \n" + "smmla v27.4s, v1.16b, v3.16b \n" + "subs w4, w4, #1 \n" + "smmla v28.4s, v0.16b, v4.16b \n" + "smmla v29.4s, v1.16b, v4.16b \n" + "smmla v30.4s, v0.16b, v5.16b \n" + "smmla v31.4s, v1.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v0.16b, v3.4b[0] \n" + "sdot v21.4s, v0.16b, v3.4b[1] \n" + "sdot v22.4s, v0.16b, v3.4b[2] \n" + "sdot v23.4s, v0.16b, v3.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v1.16b, v4.4b[0] \n" + "sdot v17.4s, v1.16b, v4.4b[1] \n" + "sdot v18.4s, v1.16b, v4.4b[2] \n" + "sdot v19.4s, v1.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v5.4b[0] \n" + "sdot v21.4s, v1.16b, v5.4b[1] \n" + "sdot v22.4s, v1.16b, v5.4b[2] \n" + "sdot v23.4s, v1.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + "uzp1 v4.4s, v28.4s, v29.4s \n" + "uzp2 v5.4s, v28.4s, v29.4s \n" + "uzp1 v6.4s, v30.4s, v31.4s \n" + "uzp2 v7.4s, v30.4s, v31.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b, v3.16b}, [%2], #32 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v0.16b, v3.4b[0] \n" + "sdot v21.4s, v0.16b, v3.4b[1] \n" + "sdot v22.4s, v0.16b, v3.4b[2] \n" + "sdot v23.4s, v0.16b, v3.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v5.16b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v5.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "rev64 v7.8h, v5.8h \n" + "smull2 v13.8h, v0.16b, v7.16b \n" + "smull2 v15.8h, v2.16b, v7.16b \n" + "ext v1.16b, v0.16b, v0.16b, #8 \n" + "ext v3.16b, v2.16b, v2.16b, #8 \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v4.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v6.16b \n" + "smlal2 v15.8h, v3.16b, v6.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v23.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 
\n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.16b}, [%2], #16 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "dup v4.8h, v1.h[4] \n" + "dup v5.8h, v1.h[5] \n" + "dup v6.8h, v1.h[6] \n" + "dup v7.8h, v1.h[7] \n" + "smull v12.8h, v0.8b, v4.8b \n" + "smull v13.8h, v0.8b, v5.8b \n" + "smull v14.8h, v0.8b, v6.8b \n" + "smull v15.8h, v0.8b, v7.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1r {v0.2d}, [%1] \n" + "add %1, %1, #8 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1 {v1.8b}, [%2], #8 \n" + "add %1, %1, #4 \n" + "dup v8.8h, v1.h[0] \n" + "dup v9.8h, v1.h[1] \n" + "dup v10.8h, v1.h[2] \n" + "dup v11.8h, v1.h[3] \n" + "uzp1 v2.8b, v8.8b, v9.8b \n" + "uzp2 v3.8b, v8.8b, v9.8b \n" + "uzp1 v4.8b, v10.8b, v11.8b \n" + "uzp2 v5.8b, v10.8b, v11.8b \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v0.8b, v3.8b \n" + "smull v10.8h, v0.8b, v4.8b \n" + "smull v11.8h, v0.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw2 v18.4s, v18.4s, v8.8h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw v21.4s, v21.4s, v11.4h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1 {v2.8b}, [%2], #8 \n" + "add %1, %1, #4 \n" + "ext v1.8b, v0.8b, v0.8b, #2 \n" + "rev32 v3.8b, v2.8b \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v1.8b, v2.8b \n" + "smull v10.8h, v0.8b, v3.8b \n" + "smull v11.8h, v1.8b, v3.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // 
NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); + int32x4_t _sum20 = vdupq_n_s32(0); + int32x4_t _sum21 = vdupq_n_s32(0); + int32x4_t _sum30 = vdupq_n_s32(0); + int32x4_t _sum31 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 22222222 33333333 + // 44444444 55555555 66666666 77777777 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB0); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB0); + _sum10 = vmmlaq_s32(_sum10, _pA0, _pB1); + _sum11 = vmmlaq_s32(_sum11, _pA1, _pB1); + _sum20 = vmmlaq_s32(_sum20, _pA0, _pB2); + _sum21 = vmmlaq_s32(_sum21, _pA1, _pB2); + _sum30 = vmmlaq_s32(_sum30, _pA0, _pB3); + _sum31 = vmmlaq_s32(_sum31, _pA1, _pB3); + + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // a2 a3 b2 b3 + // c2 c3 d2 d3 + // a4 a5 b4 b5 + // c4 c5 d4 d5 + // a6 a7 b6 b7 + // c6 c7 d6 d7 + + pA += 32; + pB += 64; + } + int32x4x2_t _ss0 = vuzpq_s32(_sum00, _sum01); + int32x4x2_t _ss1 = vuzpq_s32(_sum10, _sum11); + int32x4x2_t _ss2 = vuzpq_s32(_sum20, _sum21); + int32x4x2_t _ss3 = vuzpq_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, _ss3.val[1]); + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA0, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA0, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA0, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA0, _pB1, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB2, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB2, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB2, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB2, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB3, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB3, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB3, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB3, 3); + + pA += 32; + pB += 64; + 
} +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + _sum0 = vdotq_laneq_s32(_sum0, _pA, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA, _pB1, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA02 = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB2 = vld1q_s8(pB + 16); + + int8x16_t _pA13 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA02))); + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + int8x16_t _pB3 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA02), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA13), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB1)); + int16x8_t _s5 = vmull_s8(vget_low_s8(_pA02), vget_high_s8(_pB1)); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB1)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA13), vget_high_s8(_pB1)); + + _s0 = vmlal_s8(_s0, vget_high_s8(_pA02), vget_low_s8(_pB2)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA02), vget_high_s8(_pB2)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA13), vget_low_s8(_pB2)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA13), vget_high_s8(_pB2)); + _s4 = vmlal_s8(_s4, vget_high_s8(_pA02), vget_low_s8(_pB3)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA02), vget_high_s8(_pB3)); + _s6 = vmlal_s8(_s6, vget_high_s8(_pA13), vget_low_s8(_pB3)); + _s7 = vmlal_s8(_s7, vget_high_s8(_pA13), vget_high_s8(_pB3)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x16_t _pB01 = vld1q_s8(pB); + + // aabbccdd + + // 00112233 44556677 + + int16x8_t _s0 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 0))); + int16x8_t _s1 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 1))); + int16x8_t _s2 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 2))); + int16x8_t _s3 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 3))); + int16x8_t _s4 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 0))); + int16x8_t _s5 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 1))); + int16x8_t _s6 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 2))); + int16x8_t _s7 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 3))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = 
vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + + // aabbccdd + // ccddaabb + + int8x8_t _pA1 = vreinterpret_s8_s32(vrev64_s32(vreinterpret_s32_s8(_pA0))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(_pA0, vget_low_s8(_pB1)); + int16x8_t _s5 = vmull_s8(_pA0, vget_high_s8(_pB1)); + int16x8_t _s6 = vmull_s8(_pA1, vget_low_s8(_pB1)); + int16x8_t _s7 = vmull_s8(_pA1, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pAA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_s8(pB); + + // abcdabcd + // 01234567 -> 01010101 23232323 45454545 67676767 + int8x8_t _pB0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0)); + int8x8_t _pB2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1)); + int8x8_t _pB4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2)); + int8x8_t _pB6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3)); + + int8x8x2_t _pB0123 = vuzp_s8(_pB0, _pB2); + int8x8x2_t _pB4567 = vuzp_s8(_pB4, _pB6); + + int16x8_t _s02 = vmull_s8(_pAA, _pB0123.val[0]); + int16x8_t _s13 = vmull_s8(_pAA, _pB0123.val[1]); + int16x8_t _s46 = vmull_s8(_pAA, _pB4567.val[0]); + int16x8_t _s57 = vmull_s8(_pAA, _pB4567.val[1]); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s02)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s13)); + _sum2 = vaddw_s16(_sum2, vget_high_s16(_s02)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s13)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s46)); + _sum5 = vaddw_s16(_sum5, vget_low_s16(_s57)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s46)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s57)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB0 = vld1_s8(pB); + + // abcd abcd + // cdab cdab + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 2); + + // 0123 4567 + // 3210 7654 + + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA0, _pB1); + int16x8_t _s67 = vmull_s8(_pA1, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + 
vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v4.16b \n" + "smmla v25.4s, v1.16b, v4.16b \n" + "subs w4, w4, #1 \n" + "smmla v26.4s, v0.16b, v5.16b \n" + "smmla v27.4s, v1.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v1.16b, v5.4b[0] \n" + "sdot v17.4s, v1.16b, v5.4b[1] \n" + "sdot v18.4s, v1.16b, v5.4b[2] \n" + "sdot v19.4s, v1.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v4.16b}, [%2], #16 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "rev64 v1.4s, v0.4s \n" + "smull v9.8h, v1.8b, v4.8b \n" + "rev64 v5.8h, v4.8h \n" + "smull v10.8h, v0.8b, v5.8b \n" + "smull v11.8h, v1.8b, v5.8b \n" + "smlal2 v8.8h, v0.16b, v4.16b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal2 v10.8h, v0.16b, v5.16b \n" + "smlal2 v11.8h, v1.16b, v5.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v4.4h, v1.h[0] \n" + "dup v5.4h, v1.h[1] \n" + "dup v6.4h, v1.h[2] \n" + "dup v7.4h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + 
"sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v2.8b}, [%2], #8 \n" + "ext v1.8b, v0.8b, v0.8b, #4 \n" + "rev64 v3.4h, v2.4h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v1.8b, v2.8b \n" + "smull v10.8h, v0.8b, v3.8b \n" + "smull v11.8h, v1.8b, v3.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1r {v1.2s}, [%2] \n" + "add %1, %1, #4 \n" + "add %2, %2, #4 \n" + "zip1 v1.8b, v1.8b, v1.8b \n" + "zip1 v2.4h, v1.4h, v1.4h \n" + "zip2 v3.4h, v1.4h, v1.4h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v0.8b, v3.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1] \n" + "ld1r {v4.2s}, [%2] \n" + "add %1, %1, #4 \n" + "add %2, %2, #4 \n" + "rev32 v1.4h, v0.4h \n" + "zip1 v0.2s, v0.2s, v1.2s \n" + "rev32 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d16-d23} \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + ".align 4 \n" + "2: \n" + "pld [%1, #256] \n" + "vld1.s8 {d0-d1}, [%1 :64]! \n" + "pld [%2, #128] \n" + "vld1.s8 {d4-d5}, [%2]! \n" + "vrev64.32 q1, q0 \n" + "vmull.s8 q4, d0, d4 \n" + "vrev64.16 q3, q2 \n" + "vmull.s8 q5, d2, d4 \n" + "vmull.s8 q6, d0, d6 \n" + "vmull.s8 q7, d2, d6 \n" + "vmlal.s8 q4, d1, d5 \n" + "vmlal.s8 q5, d3, d5 \n" + "vmlal.s8 q6, d1, d7 \n" + "vmlal.s8 q7, d3, d7 \n" + "subs r4, r4, #1 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + "bne 2b \n" + + "3: \n" + "and r4, %6, #2 \n" // r4 = remain = max_kk & 2 + "cmp r4, #0 \n" + "beq 4f \n" + + // kk += 2 part + "vld1.s8 {d0}, [%1 :64]! \n" + "vld1.s8 {d4}, [%2]! \n" + "vext.8 d1, d0, d0, #4 \n" + "vrev64.16 d5, d4 \n" + "vmull.s8 q4, d0, d4 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d0, d5 \n" + "vmull.s8 q7, d1, d5 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + + "4: \n" + "and r4, %6, #1 \n" // r4 = remain = max_kk & 1 + "cmp r4, #0 \n" + "beq 5f \n" + + // kk += 1 part + "vld1.s32 {d0[0]}, [%1]! \n" + "vld1.s32 {d2[]}, [%2]! 
\n" + "vrev32.16 d1, d0 \n" + "vrev32.s8 d3, d2 \n" + "vzip.32 d0, d1 \n" + "vmull.s8 q4, d0, d2 \n" + "vmull.s8 q5, d0, d3 \n" + "vaddw.s16 q8, d8 \n" + "vaddw.s16 q9, d9 \n" + "vaddw.s16 q10, d10 \n" + "vaddw.s16 q11, d11 \n" + + "5: \n" + "vstm %0!, {d16-d23} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 22222222 33333333 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB0); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB0); + _sum10 = vmmlaq_s32(_sum10, _pA0, _pB1); + _sum11 = vmmlaq_s32(_sum11, _pA1, _pB1); + + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // a2 a3 b2 b3 + // c2 c3 d2 d3 + + pA += 32; + pB += 32; + } + int32x4x2_t _ss0 = vuzpq_s32(_sum00, _sum01); + int32x4x2_t _ss1 = vuzpq_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB1, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB1, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB1, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB1, 3); + + pA += 32; + pB += 32; + } +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + _sum0 = vdotq_laneq_s32(_sum0, _pA, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA02 = vld1q_s8(pA); + int8x16_t _pB02 = vld1q_s8(pB); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA13 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA02))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB13)); + int16x8_t _s3 = 
vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB13)); + + _s0 = vmlal_s8(_s0, vget_high_s8(_pA02), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA13), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA02), vget_high_s8(_pB13)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA13), vget_high_s8(_pB13)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s2 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s3 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // aabbccdd + // ccddaabb + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 4); + + // 00112233 + // 33221100 + + int8x8_t _pB1 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(_pA0, _pB0); + int16x8_t _s1 = vmull_s8(_pA1, _pB0); + int16x8_t _s2 = vmull_s8(_pA0, _pB1); + int16x8_t _s3 = vmull_s8(_pA1, _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + _pB = vzip_s8(_pB, _pB).val[0]; + int16x4x2_t _pB0123 = vzip_s16(vreinterpret_s16_s8(_pB), vreinterpret_s16_s8(_pB)); + + int16x8_t _s01 = vmull_s8(_pA, vreinterpret_s8_s16(_pB0123.val[0])); + int16x8_t _s23 = vmull_s8(_pA, vreinterpret_s8_s16(_pB0123.val[1])); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); +#else // __ARM_FEATURE_DOTPROD + + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // abcd.... -> cdab.... 
-> abcdcdab + int8x8_t _pA1 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + int8x8_t _pA01 = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pA0), vreinterpret_s32_s8(_pA1)).val[0]); + + // 01230123 -> 32103210 + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA01, _pB0); + int16x8_t _s23 = vmull_s8(_pA01, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB, 2); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 32; + pB += 16; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss = vuzpq_s32(_sum00, _sum01); + _sum0 = vaddq_s32(_sum0, _ss.val[0]); + _sum1 = vaddq_s32(_sum1, _ss.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum0 = vdotq_lane_s32(_sum0, _pA, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA, _pB, 1); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh + + // 00112233 -> 00110011 22332233 + // 11001100 33223322 + + int32x2x2_t _pBB = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + int8x16_t _pB02 = vreinterpretq_s8_s32(vcombine_s32(_pBB.val[0], _pBB.val[1])); + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB13)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA), vget_high_s8(_pB13)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + // aabbccdd + // 0011.... 
+ int16x8_t _s0 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aabbccdd + + // 00110011 + // 11001100 + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 2); + + int16x8_t _s0 = vmull_s8(_pA, _pB0); + int16x8_t _s1 = vmull_s8(_pA, _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcdabcd + + // 01010101 -> 00001111 + _pB = vuzp_s8(_pB, vext_s8(_pB, _pB, 1)).val[0]; + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcd abcd + + // 0101 0101 -> 0101 1010 + + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 1); + int8x8_t _pB = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB0), vreinterpret_s32_s8(_pB1)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 + + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _sum01 = vdotq_s32(_sum01, _pA0, _pBB); + _sum23 = vdotq_s32(_sum23, _pA1, _pBB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum0 = vdotq_lane_s32(_sum0, _pA1, _pB, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 32; + pB += 8; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum01, _sum23)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum0 = vdotq_lane_s32(_sum0, _pA, _pB, 0); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + int8x8_t _pB1 = vreinterpret_s8_s16(vld1_dup_s16((const short*)(pB + 2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = 
vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_dup_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + + pA += 4; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* pB = pBT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); + int32x2_t _sum20 = vdup_n_s32(0); + int32x2_t _sum21 = vdup_n_s32(0); + int32x2_t _sum30 = vdup_n_s32(0); + int32x2_t _sum31 = vdup_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum01 = vmmlaq_s32(_sum01, _pA, _pB0); + _sum23 = vmmlaq_s32(_sum23, _pA, _pB1); + _sum45 = vmmlaq_s32(_sum45, _pA, _pB2); + _sum67 = vmmlaq_s32(_sum67, _pA, _pB3); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum00 = vdot_laneq_s32(_sum00, vget_low_s8(_pA), _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_low_s8(_pA), _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_low_s8(_pA), _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_low_s8(_pA), _pB0, 3); + _sum20 = vdot_laneq_s32(_sum20, vget_low_s8(_pA), _pB1, 0); + _sum21 = vdot_laneq_s32(_sum21, vget_low_s8(_pA), _pB1, 1); + _sum30 = vdot_laneq_s32(_sum30, vget_low_s8(_pA), _pB1, 2); + _sum31 = vdot_laneq_s32(_sum31, vget_low_s8(_pA), _pB1, 3); + _sum00 = vdot_laneq_s32(_sum00, vget_high_s8(_pA), _pB2, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_high_s8(_pA), _pB2, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_high_s8(_pA), _pB2, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_high_s8(_pA), _pB2, 3); + _sum20 = vdot_laneq_s32(_sum20, vget_high_s8(_pA), _pB3, 0); + _sum21 = vdot_laneq_s32(_sum21, vget_high_s8(_pA), _pB3, 1); + _sum30 = vdot_laneq_s32(_sum30, vget_high_s8(_pA), _pB3, 2); + _sum31 = vdot_laneq_s32(_sum31, vget_high_s8(_pA), _pB3, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 64; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vcombine_s32(vget_low_s32(_sum01), vget_low_s32(_sum23))); + _sum1 = vaddq_s32(_sum1, vcombine_s32(vget_low_s32(_sum45), vget_low_s32(_sum67))); + _sum2 = vaddq_s32(_sum2, vcombine_s32(vget_high_s32(_sum01), vget_high_s32(_sum23))); + _sum3 = vaddq_s32(_sum3, vcombine_s32(vget_high_s32(_sum45), vget_high_s32(_sum67))); +#else // 
__ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + int32x2x2_t _sum2x = vzip_s32(_sum20, _sum21); + int32x2x2_t _sum3x = vzip_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum2x.val[0], _sum3x.val[0])); + _sum2 = vaddq_s32(_sum2, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); + _sum3 = vaddq_s32(_sum3, vcombine_s32(_sum2x.val[1], _sum3x.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_DOTPROD + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); + int32x2_t _sum20 = vdup_n_s32(0); + int32x2_t _sum21 = vdup_n_s32(0); + int32x2_t _sum30 = vdup_n_s32(0); + int32x2_t _sum31 = vdup_n_s32(0); +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum00 = vdot_laneq_s32(_sum00, _pA, _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, _pA, _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, _pA, _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, _pA, _pB0, 3); + _sum20 = vdot_laneq_s32(_sum20, _pA, _pB1, 0); + _sum21 = vdot_laneq_s32(_sum21, _pA, _pB1, 1); + _sum30 = vdot_laneq_s32(_sum30, _pA, _pB1, 2); + _sum31 = vdot_laneq_s32(_sum31, _pA, _pB1, 3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB0)); + _s0 = vmlal_s8(_s0, _pA2, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA2, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB1)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 32; + } +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + int32x2x2_t _sum2x = vzip_s32(_sum20, _sum21); + int32x2x2_t _sum3x = vzip_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum2x.val[0], _sum3x.val[0])); + _sum2 = vaddq_s32(_sum2, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); + _sum3 = vaddq_s32(_sum3, vcombine_s32(_sum2x.val[1], _sum3x.val[1])); +#endif // __ARM_FEATURE_DOTPROD + } + for (; kk + 1 < max_kk; kk += 2) + { + int16x4_t _pA = vreinterpret_s16_s32(vld1_dup_s32((const int*)pA)); + int8x16_t _pB = vld1q_s8(pB); + + int16x4x2_t _pA01 = vuzp_s16(_pA, _pA); + int8x8_t _pA0 = vreinterpret_s8_s16(_pA01.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pA01.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, 
_s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vld1_s8(pB); + + int8x8x2_t _pA01 = vuzp_s8(_pA, _pA); + + int16x8_t _s0 = vmull_s8(_pA01.val[0], _pB); + int16x8_t _s1 = vmull_s8(_pA01.val[1], _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s1)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); + + pA += 2; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum01 = vmmlaq_s32(_sum01, _pA, _pB0); + _sum23 = vmmlaq_s32(_sum23, _pA, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum00 = vdot_laneq_s32(_sum00, vget_low_s8(_pA), _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_low_s8(_pA), _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_low_s8(_pA), _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_low_s8(_pA), _pB0, 3); + _sum00 = vdot_laneq_s32(_sum00, vget_high_s8(_pA), _pB1, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_high_s8(_pA), _pB1, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_high_s8(_pA), _pB1, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_high_s8(_pA), _pB1, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vcombine_s32(vget_low_s32(_sum01), vget_low_s32(_sum23))); + _sum1 = vaddq_s32(_sum1, vcombine_s32(vget_high_s32(_sum01), vget_high_s32(_sum23))); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_DOTPROD + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum00 = vdot_laneq_s32(_sum00, _pA, _pB, 0); + _sum01 = vdot_laneq_s32(_sum01, _pA, _pB, 1); + _sum10 = vdot_laneq_s32(_sum10, _pA, _pB, 2); + _sum11 = vdot_laneq_s32(_sum11, _pA, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA1, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA2, vget_high_s8(_pB)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 16; + } +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); +#endif // __ARM_FEATURE_DOTPROD + } + for (; kk + 1 < max_kk; kk += 2) + { + int16x4_t _pA = vreinterpret_s16_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + int16x4x2_t _pA01 = vuzp_s16(_pA, _pA); + int8x8_t _pA0 = vreinterpret_s8_s16(_pA01.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pA01.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, _pB); + int16x8_t _s1 = vmull_s8(_pA1, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + _pA = vzip_s8(_pA, _pA).val[0]; + _pA = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 2; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { +#if __ARM_NEON + int32x4_t _sum; + + if (k == 0) + { + _sum = vdupq_n_s32(0); + } + else + { + _sum = vld1q_s32(outptr); + } + + const signed char* pA = pAT; + int kk = 0; + +#if __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum = vmmlaq_s32(_sum, _pA, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _pAA = vzipq_s32(vreinterpretq_s32_s8(_pA), vreinterpretq_s32_s8(_pA)); + int8x16_t _pA01 = vreinterpretq_s8_s32(_pAA.val[0]); + int8x16_t _pA23 = vreinterpretq_s8_s32(_pAA.val[1]); + int8x16_t _pB01 = vcombine_s8(vget_low_s8(_pB), vget_low_s8(_pB)); + int8x16_t _pB23 = vcombine_s8(vget_high_s8(_pB), vget_high_s8(_pB)); + + _sum = vdotq_s32(_sum, _pA01, _pB01); + _sum = vdotq_s32(_sum, _pA23, _pB23); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _pAA = vzip_s32(vreinterpret_s32_s8(_pA), vreinterpret_s32_s8(_pA)); + int8x16_t _pA01 = vreinterpretq_s8_s32(vcombine_s32(_pAA.val[0], _pAA.val[1])); + + int8x16_t _pB01 = vcombine_s8(_pB, _pB); + + _sum = vdotq_s32(_sum, _pA01, _pB01); +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _pA01 = vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)); + int32x2x2_t _pB01 = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + + int16x8_t _s0 = 
vmull_s8(vreinterpret_s8_s16(_pA01.val[0]), vreinterpret_s8_s32(_pB01.val[0])); + _s0 = vmlal_s8(_s0, vreinterpret_s8_s16(_pA01.val[1]), vreinterpret_s8_s32(_pB01.val[1])); + _sum = vpadalq_s16(_sum, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _pA = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + _pB = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vpadalq_s16(_sum, _s0); + + // A0 A1 A2 A3 + // B0 B1 B2 B3 + + // A0 A1 A0 A1 A2 A3 A2 A3 + // B0 B1 B2 B3 B0 B1 B2 B3 + + pA += 4; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(pB)), 0)); + + _pA = vzip_s8(_pA, _pA).val[0]; + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vaddw_s16(_sum, vget_low_s16(_s0)); + + // A0 A1 A0 A1 + // B0 B1 B0 B1 + + // A0 A0 A1 A1 + + pA += 2; + pB += 2; + } + + vst1q_s32(outptr, _sum); + + outptr += 4; +#else // __ARM_NEON + int sum00; + int sum10; + int sum01; + int sum11; + + if (k == 0) + { + sum00 = 0; + sum10 = 0; + sum01 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum10 = outptr[1]; + sum01 = outptr[2]; + sum11 = outptr[3]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + // fomit-frame-pointer implied in optimized flag spare one register + // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui +#if __OPTIMIZE__ + asm volatile( + "ldr r2, [%0], #4 \n" // int8x4_t _pA = *((int8x4_t*)pA); pA += 4; + "ldr r4, [%1], #4 \n" // int8x4_t _pB = *((int8x4_t*)pB); pB += 4; + "ror r3, r2, #8 \n" // int8x4_t _pA_r8 = __ror(_pA, 8); + "ror r5, r4, #8 \n" // int8x4_t _pB_r8 = __ror(_pB, 8); + "sxtb16 r2, r2 \n" // int16x2_t _pA0 = __sxtb16(_pA); + "sxtb16 r4, r4 \n" // int16x2_t _pA1 = __sxtb16(_pA_r8); + "sxtb16 r3, r3 \n" // int16x2_t _pB0 = __sxtb16(_pB); + "sxtb16 r5, r5 \n" // int16x2_t _pB1 = __sxtb16(_pB_r8); + "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_pA0, _pB0, sum00); + "smlad %3, r3, r4, %3 \n" // sum10 = __smlad(_pA1, _pB0, sum10); + "smlad %4, r2, r5, %4 \n" // sum01 = __smlad(_pA0, _pB1, sum01); + "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_pA1, _pB1, sum11); + : "=r"(pA), + "=r"(pB), + "=r"(sum00), + "=r"(sum10), + "=r"(sum01), + "=r"(sum11) + : "0"(pA), + "1"(pB), + "2"(sum00), + "3"(sum10), + "4"(sum01), + "5"(sum11) + : "memory", "r2", "r3", "r4", "r5"); +#else + int _pA0 = *((int*)pA); + int _pB0 = *((int*)pB); + int _pA1; + int _pB1; + asm volatile("ror %0, %1, #8" + : "=r"(_pA1) + : "r"(_pA0) + :); + asm volatile("ror %0, %1, #8" + : "=r"(_pB1) + : "r"(_pB0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pA0) + : "0"(_pA0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pA1) + : "0"(_pA1) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pB0) + : "0"(_pB0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pB1) + : "0"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum00) + : "0"(sum00), "r"(_pA0), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum10) + : "0"(sum10), "r"(_pA1), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum01) + : "0"(sum01), "r"(_pA0), "r"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : 
"=r"(sum11) + : "0"(sum11), "r"(_pA1), "r"(_pB1) + :); + pA += 4; + pB += 4; +#endif + } +#endif // __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum10 += pA[1] * pB[0]; + sum01 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum10; + outptr[2] = sum01; + outptr[3] = sum11; + + outptr += 4; +#endif // __ARM_NEON + } + for (; jj < max_jj; jj += 1) + { +#if __ARM_NEON + int32x2_t _sum; + + if (k == 0) + { + _sum = vdup_n_s32(0); + } + else + { + _sum = vld1_s32(outptr); + } +#else // __ARM_NEON + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } +#endif // __ARM_NEON + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _sum0 = vdotq_s32(_sum0, _pA, _pBB); + + pA += 16; + pB += 8; + } + int32x2_t _ss = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum = vdot_lane_s32(_sum, vget_low_s8(_pA), _pB, 0); + _sum = vdot_lane_s32(_sum, vget_high_s8(_pA), _pB, 1); + + pA += 16; + pB += 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + _sum = vdot_s32(_sum, _pA, _pB); + + pA += 8; + pB += 4; + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + _pB = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pB), vreinterpret_s16_s8(_pB)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 4; + } + int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#endif // __ARM_FEATURE_DOTPROD + int sum0 = vget_lane_s32(_sum, 0); + int sum1 = vget_lane_s32(_sum, 1); + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + const signed char* pB = pBT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + 
int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _pAA = vcombine_s8(_pA, _pA); + _sum00 = vdotq_s32(_sum00, _pAA, _pB0); + _sum01 = vdotq_s32(_sum01, _pAA, _pB1); + _sum10 = vdotq_s32(_sum10, _pAA, _pB2); + _sum11 = vdotq_s32(_sum11, _pAA, _pB3); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum1 = vdotq_lane_s32(_sum1, _pB1, _pA, 0); + _sum0 = vdotq_lane_s32(_sum0, _pB2, _pA, 1); + _sum1 = vdotq_lane_s32(_sum1, _pB3, _pA, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 8; + pB += 64; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum00, _sum01)); + _sum1 = vaddq_s32(_sum1, vpaddq_s32(_sum10, _sum11)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + // TODO + // __builtin_prefetch(pA + 16); + // __builtin_prefetch(pB + 128); + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + int8x16_t _pB4 = vld1q_s8(pB + 64); + int8x16_t _pB5 = vld1q_s8(pB + 80); + int8x16_t _pB6 = vld1q_s8(pB + 96); + int8x16_t _pB7 = vld1q_s8(pB + 112); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 3)); + int8x8_t _pA4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 0)); + int8x8_t _pA5 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 1)); + int8x8_t _pA6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 2)); + int8x8_t _pA7 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA2, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA2, vget_high_s8(_pB2)); + int16x8_t _s4 = vmull_s8(_pA4, vget_low_s8(_pB4)); + int16x8_t _s5 = vmull_s8(_pA4, vget_high_s8(_pB4)); + int16x8_t _s6 = vmull_s8(_pA6, vget_low_s8(_pB6)); + int16x8_t _s7 = vmull_s8(_pA6, vget_high_s8(_pB6)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB3)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB3)); + _s4 = vmlal_s8(_s4, _pA5, vget_low_s8(_pB5)); + _s5 = vmlal_s8(_s5, _pA5, vget_high_s8(_pB5)); + _s6 = vmlal_s8(_s6, _pA7, vget_low_s8(_pB7)); + _s7 = vmlal_s8(_s7, _pA7, vget_high_s8(_pB7)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + + pA += 16; + pB += 128; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = 
vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA2, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA2, vget_high_s8(_pB2)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB3)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB3)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 8; + pB += 64; + } + _sum0 = vaddq_s32(_sum0, _sum2); + _sum1 = vaddq_s32(_sum1, _sum3); + _sum0 = vaddq_s32(_sum0, _sum4); + _sum1 = vaddq_s32(_sum1, _sum5); + _sum0 = vaddq_s32(_sum0, _sum6); + _sum1 = vaddq_s32(_sum1, _sum7); + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum0 = vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum1 = vdotq_lane_s32(_sum1, _pB1, _pA, 0); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x16_t _pB = vld1q_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_dup_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 1; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _pAA = vcombine_s8(_pA, _pA); + _sum00 = vdotq_s32(_sum00, _pAA, _pB0); + _sum01 = vdotq_s32(_sum01, _pAA, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = 
vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum0 = vdotq_lane_s32(_sum0, _pB1, _pA, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 8; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum00, _sum01)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + // TODO + // __builtin_prefetch(pA + 16); + // __builtin_prefetch(pB + 64); + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 3)); + int8x8_t _pA4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 0)); + int8x8_t _pA5 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 1)); + int8x8_t _pA6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 2)); + int8x8_t _pA7 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + int16x8_t _s2 = vmull_s8(_pA4, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA6, vget_low_s8(_pB3)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA5, vget_high_s8(_pB2)); + _s3 = vmlal_s8(_s3, _pA7, vget_high_s8(_pB3)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 16; + pB += 64; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 8; + pB += 32; + } + _sum0 = vaddq_s32(_sum0, _sum1); + _sum0 = vaddq_s32(_sum0, _sum2); + _sum0 = vaddq_s32(_sum0, _sum3); + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum0 = vdotq_lane_s32(_sum0, _pB, _pA, 0); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 
16; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_dup_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + + pA += 1; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + + outptr += 4; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { +#if __ARM_NEON + int32x2_t _sum; + + if (k == 0) + { + _sum = vdup_n_s32(0); + } + else + { + _sum = vld1_s32(outptr); + } +#else // __ARM_NEON + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } +#endif // __ARM_NEON + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + int8x16_t _pAA = vcombine_s8(_pA, _pA); + + _sum0 = vdotq_s32(_sum0, _pAA, _pB); + + pA += 8; + pB += 16; + } + int32x2_t _ss = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + _sum = vdot_lane_s32(_sum, vget_low_s8(_pB), _pA, 0); + _sum = vdot_lane_s32(_sum, vget_high_s8(_pB), _pA, 1); + + pA += 8; + pB += 16; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_s8(pB); + + _sum = vdot_s32(_sum, _pA, _pB); + + pA += 4; + pB += 8; + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + int16x8x2_t _pAA = vzipq_s16(vreinterpretq_s16_s8(_pA), vreinterpretq_s16_s8(_pA)); + + int8x8_t _pA0 = vreinterpret_s8_s16(vget_low_s16(_pAA.val[0])); + int8x8_t _pA1 = vreinterpret_s8_s16(vget_high_s16(_pAA.val[0])); + int8x8_t _pA2 = vreinterpret_s8_s16(vget_low_s16(_pAA.val[1])); + int8x8_t _pA3 = vreinterpret_s8_s16(vget_high_s16(_pAA.val[1])); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 16; + pB += 32; + } + _sum0 = vaddq_s32(_sum0, _sum1); + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + int16x4x2_t _pAA = vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)); + + int8x8_t _pA0 = vreinterpret_s8_s16(_pAA.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pAA.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 16; + } + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + _pA = 
vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 4; + pB += 8; + } + int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#endif // __ARM_FEATURE_DOTPROD + int sum0 = vget_lane_s32(_sum, 0); + int sum1 = vget_lane_s32(_sum, 1); + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + int sum; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON + int32x4_t _sum = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + for (; kk + 31 < max_kk; kk += 32) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum = vdotq_s32(_sum, _pA0, _pB0); + _sum1 = vdotq_s32(_sum1, _pA1, _pB1); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA0), vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA1), vget_high_s8(_pB1)); + _sum = vpadalq_s16(_sum, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 32; + } + _sum = vaddq_s32(_sum, _sum1); + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum = vdotq_s32(_sum, _pA, _pB); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), vget_high_s8(_pB)); + _sum = vpadalq_s16(_sum, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vpadalq_s16(_sum, _s0); + + pA += 8; + pB += 8; + } +#if __aarch64__ + sum += vaddvq_s32(_sum); +#else + int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum)); + _ss = vpadd_s32(_ss, _ss); + sum += vget_lane_s32(_ss, 0); +#endif +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / (2 * sizeof(signed char) + sizeof(int))); + + TILE_M = std::max(8, tile_size / 8 * 8); +#if __aarch64__ + TILE_N = std::max(8, tile_size / 8 * 8); +#else + TILE_N = std::max(4, tile_size / 4 * 4); +#endif + TILE_K = std::max(8, tile_size / 8 * 8); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + + if (nn_K == 1) + { + tile_size = 
(int)((float)l2_cache_size / 2 / sizeof(signed char) / TILE_K); + + TILE_M = std::max(8, tile_size / 8 * 8); +#if __aarch64__ + TILE_N = std::max(8, tile_size / 8 * 8); +#else + TILE_N = std::max(4, tile_size / 4 * 4); +#endif + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __aarch64__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 7) / 8 * 8); +#else + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#endif + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + 7) / 8 * 8; + } + + if (constant_TILE_N > 0) + { +#if __aarch64__ + TILE_N = (constant_TILE_N + 7) / 8 * 8; +#else + TILE_N = (constant_TILE_N + 3) / 4 * 4; +#endif + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + 7) / 8 * 8; + } +} diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h new file mode 100644 index 000000000000..350f20ab4c0f --- /dev/null +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -0,0 +1,8566 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
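+
+// Quantization scheme used by the bf16 helpers in this file (a minimal sketch;
+// absmax here stands for the per-row maximum magnitude of A and B_scale for the
+// quantization scale chosen for B, both taken from the surrounding code):
+//
+//   float scale       = 127.f / absmax;              // bf16 input -> int8, per output row
+//   float out_descale = absmax / (127.f * B_scale);  // int32 accumulator -> float result
+//
+// An int32 dot product of tiles quantized this way, multiplied by out_descale,
+// approximately reproduces the corresponding row of the float A*B product.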
+ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_bf16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_bf16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); +void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); +#endif + +static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + const int K = A.w; + + // NCNN_LOGE("compute_A_tile_bf16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif + + for (int ii = 0; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); + + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + +#endif + ps += 4; + pods += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int ii = 0; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (; kk + 15 < K; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, 
vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 7 < K; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(p0[0]))); + p0++; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_bf16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_bf16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_bf16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale0, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale0, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale0, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale0, 3); + float32x4_t _p8 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[0])), _scale1, 0); + float32x4_t _p9 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[1])), _scale1, 1); + float32x4_t _pa = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[2])), _scale1, 2); + float32x4_t _pb = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[3])), _scale1, 3); + float32x4_t _pc = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[0])), _scale1, 0); + float32x4_t _pd = 
vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[1])), _scale1, 1); + float32x4_t _pe = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[2])), _scale1, 2); + float32x4_t _pf = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[3])), _scale1, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale0); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale0); + _p8 = vmulq_f32(_p8, _scale1); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale1); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale1); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale1); + _pf = vmulq_f32(_pf, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = 
vmulq_laneq_f32(bfloat2float(_p.val[0]), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(_p.val[1]), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(_p.val[2]), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(_p.val[3]), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(_q.val[0]), _scale1, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(_q.val[1]), _scale1, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(_q.val[2]), _scale1, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(_q.val[3]), _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale1); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale1); + _p7 = vmulq_f32(_p7, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p0n = bfloat2float(vget_high_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p1n = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p0n = vmulq_f32(_p0n, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p1n = vmulq_f32(_p1n, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p0n, _p1n); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = 
bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale1), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); + _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); + _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); + _pc = vmulq_lane_f32(_pc, vget_high_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = 
vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + A_hstep * 7)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = 
vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + float32x4x2_t _scale01 = vzipq_f32(_scale0, _scale0); + float32x4x2_t _scale23 = vzipq_f32(_scale1, _scale1); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + _p45 = vmulq_f32(_p45, _scale23.val[0]); + _p67 = vmulq_f32(_p67, _scale23.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = 
bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(_p.val[0]), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(_p.val[1]), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(_p.val[2]), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(_p.val[3]), _scale, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = 
vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + float32x4x2_t _scale01 = vzipq_f32(_scale, _scale); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + 
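+                // Descriptive note on this 4-row x 2-k micro-step: vzipq_f32(_scale, _scale) above duplicates the
+                // four per-row scales so both k values of a row are multiplied by the same scale, and float2int8
+                // below is presumably the header's round-to-nearest, saturating fp32->int8 narrow (its definition
+                // is outside this hunk), collapsing the eight scaled floats into one 8-byte group of the packed int8 A tile.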
+ int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[A_hstep], _p, 1); + _p = vset_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[A_hstep * 3], _p, 3); + float32x4_t _p0 = bfloat2float(_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); + pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); + pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale = scales[ii]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + 
float32x4_t _scale = vdupq_n_f32(scale); + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.dims == 3 ? A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_bf16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + +#if __ARM_NEON +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif +#endif + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { + int ii = 0; + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (int kk = 0; kk < K; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); + float32x2_t _aa3 = vmax_f32(vget_low_f32(_absmax3), vget_high_f32(_absmax3)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); + float32x4_t _absmax = vcombine_f32(_aa01, _aa23); + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax); + float32x4_t _out_descale = vdivq_f32(_absmax, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / 
v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 12)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep * 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int ii = 0; +#if __ARM_NEON + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = 
vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } +#endif // __ARM_NEON + for (; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float absmax = 0.f; + for (int kk = 0; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(p0[0]))); + p0 += A_hstep; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_bf16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_bf16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_bf16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + _p8 = vmulq_laneq_f32(_p8, _scale0, 0); + _p9 = vmulq_laneq_f32(_p9, _scale0, 1); + _pa = vmulq_laneq_f32(_pa, _scale0, 2); + _pb = vmulq_laneq_f32(_pb, _scale0, 3); + _pc = vmulq_laneq_f32(_pc, _scale1, 0); + _pd = vmulq_laneq_f32(_pd, _scale1, 1); + _pe = vmulq_laneq_f32(_pe, _scale1, 2); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale0), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale0), 1); + _pa = vmulq_lane_f32(_pa, vget_high_f32(_scale0), 0); + _pb = vmulq_lane_f32(_pb, vget_high_f32(_scale0), 1); + _pc = vmulq_lane_f32(_pc, vget_low_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_low_f32(_scale1), 1); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 0); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = 
float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = 
vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + 
vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + _p4 = 
vmulq_laneq_f32(_p4, _scale, 0); + _p5 = vmulq_laneq_f32(_p5, _scale, 1); + _p6 = vmulq_laneq_f32(_p6, _scale, 2); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = 
bfloat2float(vld1_u16(p0 + A_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t 
_p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vzipq_f32(_scale0, _scale1).val[0]; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = 
vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p46 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p57 = bfloat2float(vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep + 0]) * scale0); + pp[2] = float2int8(bfloat16_to_float32(p0[1]) * scale1); + pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale = scales[ii]; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = 
bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 15], _q, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void compute_B_bf16_int8_scale(const Mat& B, float& scale) +{ + float absmax = 0.f; +#if __ARM_NEON + 
float32x4_t _absmax = vdupq_n_f32(0.f); +#endif + for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++) + { + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + const unsigned short* ptr = (const unsigned short*)B + i * B_hstep * B.elempack; + + const int size = B.w * B.elempack; + + int j = 0; +#if __ARM_NEON + for (; j + 7 < size; j += 8) + { + uint16x8_t _p = vld1q_u16(ptr); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + _absmax = vmaxq_f32(_absmax, vabsq_f32(_p0)); + _absmax = vmaxq_f32(_absmax, vabsq_f32(_p1)); + ptr += 8; + } + for (; j + 3 < size; j += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + _absmax = vmaxq_f32(_absmax, vabsq_f32(_p)); + ptr += 4; + } +#endif + for (; j < size; j++) + { + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(ptr[0]))); + ptr++; + } + } +#if __ARM_NEON + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif + + scale = absmax == 0.f ? 1.f : 127.f / absmax; +} + +static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_bf16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_bf16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_bf16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale); + float32x4_t _p8 = vmulq_f32(bfloat2float(vget_low_u16(_q.val[0])), _scale); + float32x4_t _p9 = vmulq_f32(bfloat2float(vget_low_u16(_q.val[1])), _scale); + float32x4_t _pa = vmulq_f32(bfloat2float(vget_low_u16(_q.val[2])), _scale); + float32x4_t _pb = vmulq_f32(bfloat2float(vget_low_u16(_q.val[3])), _scale); + float32x4_t _pc = vmulq_f32(bfloat2float(vget_high_u16(_q.val[0])), _scale); + float32x4_t _pd = vmulq_f32(bfloat2float(vget_high_u16(_q.val[1])), _scale); + float32x4_t _pe = vmulq_f32(bfloat2float(vget_high_u16(_q.val[2])), _scale); + float32x4_t _pf = 
vmulq_f32(bfloat2float(vget_high_u16(_q.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(bfloat2float(_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(_p.val[3]), _scale); + 
float32x4_t _p4 = vmulq_f32(bfloat2float(_q.val[0]), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(_q.val[1]), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(_q.val[2]), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(_q.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); 
+ float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = 
vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0++; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = 
vmulq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_f32(bfloat2float(_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(_p.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + 
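// The scale consumed by these packers is the dynamic per-matrix quantization scale
// computed at the top of this header from the absolute maximum of B:
// scale = absmax == 0 ? 1 : 127 / absmax. A scalar equivalent, with illustrative names:
#include <algorithm>
#include <cmath>
#include <cstdint>

static float compute_int8_scale_ref(const uint16_t* bf16, int size)
{
    float absmax = 0.f;
    for (int i = 0; i < size; i++)
    {
        union { uint32_t u; float f; } tmp;
        tmp.u = (uint32_t)bf16[i] << 16; // widen bf16 to fp32
        absmax = std::max(absmax, fabsf(tmp.f));
    }
    return absmax == 0.f ? 1.f : 127.f / absmax;
}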
p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = 
vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[B_hstep], _p, 1); + _p = vset_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[B_hstep * 3], _p, 3); + float32x4_t _p0 = bfloat2float(_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + 
{ + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_bf16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_bf16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_bf16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = 
float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, 
_scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + 
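// The interleave patterns chosen under __ARM_FEATURE_DOTPROD and __ARM_FEATURE_MATMUL_INT8
// exist so the int8 microkernel can consume 4 (SDOT) or 8 (SMMLA) consecutive k values per
// column in one instruction. Stripped of elempack handling and the per-ISA register
// shuffles, the packing amounts to a k-grouped copy; a simplified sketch under those
// assumptions, not the exact byte order emitted above:
#include <cstdint>

// B is already-quantized int8, row-major with leading dimension ldb;
// emit 4 consecutive k values of column j+0, then 4 of column j+1, and so on
static void pack_B_tile_int8_dot_ref(const int8_t* B, int ldb, int j, int max_jj,
                                     int k, int max_kk, int8_t* BT)
{
    int8_t* pp = BT;
    for (int kk = 0; kk + 3 < max_kk; kk += 4)
    {
        for (int jj = 0; jj < max_jj; jj++)
        {
            const int8_t* p0 = B + (j + jj) * ldb + (k + kk);
            pp[0] = p0[0];
            pp[1] = p0[1];
            pp[2] = p0[2];
            pp[3] = p0[3];
            pp += 4;
        }
    }
    // k tail (max_kk % 4) omitted for brevity
}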
float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + 
int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; 
kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); 
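// Unlike pack_B_tile_bf16_to_int8, the transpose variant reads B stored with K as the
// outer dimension, so a column is gathered by striding one row (B_hstep) per k step.
// A scalar sketch of that addressing for the elempack == 1 case, with illustrative names:
#include <cmath>
#include <cstdint>

static void transpose_pack_column_ref(const uint16_t* B, int B_hstep, int j, int k,
                                      int max_kk, float scale, int8_t* pp)
{
    const uint16_t* p0 = B + k * B_hstep + j; // element (k, j) of a K x N bf16 matrix
    for (int kk = 0; kk < max_kk; kk++)
    {
        union { uint32_t u; float f; } tmp;
        tmp.u = (uint32_t)p0[0] << 16;
        int v = (int)lroundf(tmp.f * scale);
        if (v > 127) v = 127;
        if (v < -127) v = -127;
        *pp++ = (int8_t)v;
        p0 += B_hstep; // next k, same column
    }
}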
+ _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p46 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p57 = bfloat2float(vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep + 
0]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 15], _q, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = 
vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_bf16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 
f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); 
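// Each 32-bit accumulator produced by the int8 microkernel holds sum_k(a_int8 * b_int8),
// so the unpack step recovers fp32 with one multiply by the per-output-row descale before
// the C/alpha/beta epilogue and the bf16 store. A scalar sketch of that conversion,
// assuming descale already folds the A and B quantization scales; the truncating bf16
// store is illustrative (the real helper may round):
#include <cstdint>

static uint16_t dequantize_store_bf16_ref(int32_t sum, float descale)
{
    float f = (float)sum * descale; // alpha, beta and the optional C term are applied here
    union { uint32_t u; float f32; } tmp;
    tmp.f32 = f;
    return (uint16_t)(tmp.u >> 16); // keep the top 16 bits as bf16
}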
+ float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, 
_beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _cc = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_cc)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_cc)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + 
_fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + uint16x4_t _bf8 = float2bfloat(_f8); + uint16x4_t _bf9 = float2bfloat(_f9); + uint16x4_t _bfa = float2bfloat(_fa); + uint16x4_t _bfb = float2bfloat(_fb); + uint16x4_t _bfc = float2bfloat(_fc); + uint16x4_t _bfd = float2bfloat(_fd); + uint16x4_t _bfe = float2bfloat(_fe); + uint16x4_t _bff = float2bfloat(_ff); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + 16, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + 24, vcombine_u16(_bf6, _bf7)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf8, _bf9)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_bfa, _bfb)); + vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(_bfc, _bfd)); + vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(_bfe, _bff)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + transpose4x4_u16(_bf8, _bf9, _bfa, _bfb); + transpose4x4_u16(_bfc, _bfd, _bfe, _bff); + vst1q_u16(p0, vcombine_u16(_bf0, _bf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_bf1, _bf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_bf2, _bf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_bf3, _bf7)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf8, _bfc)); + vst1q_u16(p0 + out_hstep * 5, vcombine_u16(_bf9, _bfd)); + vst1q_u16(p0 + out_hstep * 6, vcombine_u16(_bfa, _bfe)); + vst1q_u16(p0 + out_hstep * 7, vcombine_u16(_bfb, _bff)); + p0 += 8; + } + + pp += 64; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = 
vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + 
uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_bf6, _bf7)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + vst1_u16(p0 + out_hstep * 4, _bf4); + vst1_u16(p0 + out_hstep * 5, _bf5); + vst1_u16(p0 + out_hstep * 6, _bf6); + vst1_u16(p0 + out_hstep * 7, _bf7); + p0 += 4; + } + + pp += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // 
from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep * 4); + pC += 8; + } + if (c_elempack == 1) + { + _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[1], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep + 1], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c01, 7); + _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[c_hstep * 4], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep * 5], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 6], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 7], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + pC += 2; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + 
uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf2, _bf3)); + p0 += 8; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + p0 += 2; + } + + pp += 16; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + pC += 1; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0++; + } + + pp += 8; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + 
ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = 
vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + if (c_elempack == 1) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + pC += 8; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + 16, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + 24, vcombine_u16(_bf6, _bf7)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + vst1q_u16(p0, 
vcombine_u16(_bf0, _bf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_bf1, _bf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_bf2, _bf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_bf3, _bf7)); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep * 1); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = 
vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + if (c_elempack == 1) + { + _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; + } + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + p0 += 8; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = 
vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0 += 2; + } + + pp += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } + if (c_elempack == 1) + { + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; + } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _bf0 = float2bfloat(_f0); + + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0++; + } + + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]) * beta; + c1 = bfloat16_to_float32(pC[1]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 
1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[c_hstep], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = 
bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _bf0 = float2bfloat(_f0); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + + p0[0] = float32_to_bfloat16(f0); + p0[out_hstep] = float32_to_bfloat16(f1); + + pp += 2; + p0++; + } + } + for (; ii < max_ii; ii += 1) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + 
vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_u16(p0, float2bfloat(_f0)); + + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _cc = float32x2_t(); + _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); + _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); + _f0 = vmla_n_f32(_f0, _cc, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = float32_to_bfloat16(f0); + + pp += 1; + p0++; + } + } +} + +static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("transpose_unpack_output_tile_int32_to_bf16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = 
vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = 
vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); 
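+                        // beta != 1.f path: accumulate the transposed C rows with scaling, _f += _c * beta (vmlaq_f32)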
+ _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x8_t _bf0 = vcombine_u16(float2bfloat(_f0), float2bfloat(_f8)); + uint16x8_t _bf1 = vcombine_u16(float2bfloat(_f1), float2bfloat(_f9)); + uint16x8_t _bf2 = vcombine_u16(float2bfloat(_f2), float2bfloat(_fa)); + uint16x8_t _bf3 = vcombine_u16(float2bfloat(_f3), float2bfloat(_fb)); + uint16x8_t _bf4 = vcombine_u16(float2bfloat(_f4), float2bfloat(_fc)); + uint16x8_t _bf5 = 
vcombine_u16(float2bfloat(_f5), float2bfloat(_fd)); + uint16x8_t _bf6 = vcombine_u16(float2bfloat(_f6), float2bfloat(_fe)); + uint16x8_t _bf7 = vcombine_u16(float2bfloat(_f7), float2bfloat(_ff)); + + if (out_elempack == 4) + { + uint16x8x4_t _bfa; + uint16x8x4_t _bfb; + _bfa.val[0] = _bf0; + _bfa.val[1] = _bf1; + _bfa.val[2] = _bf2; + _bfa.val[3] = _bf3; + _bfb.val[0] = _bf4; + _bfb.val[1] = _bf5; + _bfb.val[2] = _bf6; + _bfb.val[3] = _bf7; + vst4q_u16(p0, _bfa); + vst4q_u16(p0 + out_hstep * 4, _bfb); + } + if (out_elempack == 1) + { + vst1q_u16(p0, _bf0); + vst1q_u16(p0 + out_hstep, _bf1); + vst1q_u16(p0 + out_hstep * 2, _bf2); + vst1q_u16(p0 + out_hstep * 3, _bf3); + vst1q_u16(p0 + out_hstep * 4, _bf4); + vst1q_u16(p0 + out_hstep * 5, _bf5); + vst1q_u16(p0 + out_hstep * 6, _bf6); + vst1q_u16(p0 + out_hstep * 7, _bf7); + } + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); 
+ _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + 
float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x8_t _bf0 = vcombine_u16(float2bfloat(_f0), float2bfloat(_f4)); + uint16x8_t _bf1 = vcombine_u16(float2bfloat(_f1), float2bfloat(_f5)); + uint16x8_t _bf2 = vcombine_u16(float2bfloat(_f2), float2bfloat(_f6)); + uint16x8_t _bf3 = vcombine_u16(float2bfloat(_f3), float2bfloat(_f7)); + + if (out_elempack == 4) + { + uint16x8x4_t _bf; + _bf.val[0] = _bf0; + _bf.val[1] = _bf1; + _bf.val[2] = _bf2; + _bf.val[3] = _bf3; + vst4q_u16(p0, _bf); + } + if (out_elempack == 1) + { + vst1q_u16(p0, _bf0); + vst1q_u16(p0 + out_hstep, _bf1); + vst1q_u16(p0 + out_hstep * 2, _bf2); + vst1q_u16(p0 + out_hstep * 3, _bf3); + } + + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 8; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 
2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[1], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_low_u16(_c23)); + _c2 = bfloat2float(vget_high_u16(_c01)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 2; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + pC += 1; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = 
vmulq_f32(_f1, _alpha); + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + pp += 8; + p0 += out_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); 
+ _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + if (c_elempack == 1) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + pC += 8; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + 
uint16x4x4_t _bfa; + uint16x4x4_t _bfb; + _bfa.val[0] = _bf0; + _bfa.val[1] = _bf1; + _bfa.val[2] = _bf2; + _bfa.val[3] = _bf3; + _bfb.val[0] = _bf4; + _bfb.val[1] = _bf5; + _bfb.val[2] = _bf6; + _bfb.val[3] = _bf7; + vst4_u16(p0, _bfa); + vst4_u16(p0 + out_hstep * 4, _bfb); + } + if (out_elempack == 1) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + vst1_u16(p0 + out_hstep * 4, _bf4); + vst1_u16(p0 + out_hstep * 5, _bf5); + vst1_u16(p0 + out_hstep * 6, _bf6); + vst1_u16(p0 + out_hstep * 7, _bf7); + } + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = 
vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_elempack == 4) + { + uint16x4x4_t _bf; + _bf.val[0] = _bf0; + _bf.val[1] = _bf1; + _bf.val[2] = _bf2; + _bf.val[3] = _bf3; + vst4_u16(p0, _bf); + } + if (out_elempack == 1) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + } + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + if (c_elempack == 1) + { + _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; + } + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + + pp += 8; 
+ p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } + if (c_elempack == 1) + { + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; + } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_u16(p0, float2bfloat(_f0)); + pp += 4; + p0 += out_hstep; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale01 = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]) * beta; + c1 = bfloat16_to_float32(pC[1]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if 
(broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf2)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf1, _bf3)); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf2, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + } + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + } + + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + 
{ + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[0], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _bf0 = float2bfloat(_f0); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + c0 = bfloat16_to_float32(pC[0]) * beta; + f0 += c0; + f1 += c0; + pC += 1; + } + } + + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + + p0[0] = float32_to_bfloat16(f0); + p0[1] = float32_to_bfloat16(f1); + pp += 2; + p0 += out_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + 
_f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + } + else + { + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + vst1_u16(p0 + out_hstep * 8, _bf2); + vst1_u16(p0 + out_hstep * 12, _bf3); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + } + } + + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else + { + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = 
vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + } + } + + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _bf0 = float2bfloat(_f0); + + if (out_hstep == 1) + { + vst1_u16(p0, _bf0); + } + else + { + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + } + } + + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _c = float32x2_t(); + _c = vset_lane_f32(bfloat16_to_float32(pC[0]), _c, 0); + _c = vset_lane_f32(bfloat16_to_float32(pC[1]), _c, 1); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = float32_to_bfloat16(f0); + + pp += 1; + p0 += out_hstep; + } + } +} diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h new file mode 100644 index 000000000000..e096a6caf6f6 --- /dev/null +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -0,0 +1,10368 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
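+
+// The declarations below are runtime-dispatched variants of the fp16 int8 gemm kernels
+// (i8mm / dotprod / fp16 vector arithmetic), selected at runtime via cpu_support_arm_* checks.
+//
+// Illustrative scalar sketch (not a definitive implementation) of the per-row quantization
+// scales computed by the vectorized loops further below, for the elempack == 1 case and
+// assuming direct __fp16 loads: scale = 127 / absmax, out_descale = absmax / (127 * B_scale).
+//
+//   for (int ii = 0; ii < max_ii; ii++)
+//   {
+//       const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep;
+//       float absmax = 0.f;
+//       for (int kk = 0; kk < K; kk++)
+//           absmax = std::max(absmax, fabsf((float)p0[kk]));
+//       ps[0] = 127.f / absmax;
+//       pods[0] = absmax / v127_B_scale;
+//       ps++;
+//       pods++;
+//   }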
+ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_fp16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); +void transpose_unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void compute_A_tile_fp16_int8_scales_asimdhp(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii); +void transpose_compute_A_tile_fp16_int8_scales_asimdhp(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii); +void compute_B_fp16_int8_scale_asimdhp(const Mat& B, float& scale); +#endif + +static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (ncnn::cpu_support_arm_asimdhp()) + { + compute_A_tile_fp16_int8_scales_asimdhp(A, scales, B_scale, out_descales, i, max_ii); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + const int K = A.w; + + // NCNN_LOGE("compute_A_tile_fp16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (elempack == 8) + { + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); + + for (int ii = 0; ii + 7 < max_ii; ii += 8) + { + const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk < K; kk++) + { + float16x8_t _p = vld1q_f16(p0); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); + p0 += 8; + } + float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_amax0)); + float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_amax0)); + + float32x4_t _scale0 = vdivq_f32(_v127, _absmax0); + float32x4_t _scale1 = vdivq_f32(_v127, _absmax1); + float32x4_t _out_descale0 = vdivq_f32(_absmax0, _v127_B_scale); + float32x4_t _out_descale1 = vdivq_f32(_absmax1, _v127_B_scale); + + vst1q_f32(ps, _scale0); + vst1q_f32(ps + 4, _scale1); + vst1q_f32(pods, _out_descale0); + vst1q_f32(pods + 4, _out_descale1); + + ps += 8; + pods += 8; + } + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (elempack == 4) + { +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif + + for (int ii = 0; ii + 3 < max_ii; ii += 4) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p = vld1q_f16(p0); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); + p0 += 8; + } + float16x4_t _amax = vmax_f16(vget_low_f16(_amax0), 
vget_high_f16(_amax0)); + for (; kk < K; kk++) + { + float16x4_t _p = vld1_f16(p0); + _amax = vmax_f16(_amax, vabs_f16(_p)); + p0 += 4; + } + float32x4_t _absmax0 = vcvt_f32_f16(_amax); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); + + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + +#endif + ps += 4; + pods += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int ii = 0; ii < max_ii; ii++) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 31 < K; kk += 32) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 
= vmaxq_f16(_amax1, _amax3); + for (; kk + 15 < K; kk += 16) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 7 < K; kk += 8) + { + float16x8_t _p = vld1q_f16(p0); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); + p0 += 8; + } + absmax = (float)vmaxvq_f16(_amax0); + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(p0[0])); + p0++; + } +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (; kk + 15 < K; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 7 < K; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } +#if __aarch64__ + absmax = vmaxvq_f32(_absmax0); +#else + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(float16_to_float32(p0[0]))); + p0++; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_fp16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_fp16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_fp16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + 32); + uint16x8_t _u = vld1q_u16(p0 + 40); + uint16x8_t _v = vld1q_u16(p0 + 48); + uint16x8_t _w = vld1q_u16(p0 + 56); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _p04 = vzip_s8(float2int8(_p0, _p1), float2int8(_p8, _p9)); + int8x8x2_t _p15 = vzip_s8(float2int8(_p2, _p3), float2int8(_pa, _pb)); + int8x8x2_t _p26 = vzip_s8(float2int8(_p4, _p5), float2int8(_pc, _pd)); + int8x8x2_t _p37 = vzip_s8(float2int8(_p6, _p7), float2int8(_pe, _pf)); + + int8x16x4_t _rr; + _rr.val[0] = vcombine_s8(_p04.val[0], _p04.val[1]); + _rr.val[1] = vcombine_s8(_p15.val[0], _p15.val[1]); + _rr.val[2] = vcombine_s8(_p26.val[0], _p26.val[1]); + _rr.val[3] = vcombine_s8(_p37.val[0], _p37.val[1]); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x16x4_t _rr; + _rr.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p8, _p9)); + _rr.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_pa, _pb)); + _rr.val[2] = vcombine_s8(float2int8(_p4, _p5), float2int8(_pc, _pd)); + _rr.val[3] = vcombine_s8(float2int8(_p6, _p7), float2int8(_pe, _pf)); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst4q_s8(pp, _rr); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), 
float2int8(_p6, _p7)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p8, _p9), float2int8(_pc, _pd)); + _r23.val[1] = vcombine_s8(float2int8(_pa, _pb), float2int8(_pe, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 64; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p01 = vld1q_u16(p0); + uint16x8_t _p23 = vld1q_u16(p0 + 8); + + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p01 = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[0])), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[1])), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[2])), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[3])), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[0])), _scale0, 0); 
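+                // note: the lanes of _scale0 cover rows ii..ii+3 and _scale1 covers rows ii+4..ii+7,
+                // so each fp16 value is multiplied by its own per-row 127/absmax scale before the int8 narrowing below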
+ float32x4_t _p5 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[1])), _scale0, 1); + float32x4_t _p6 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[2])), _scale0, 2); + float32x4_t _p7 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[3])), _scale0, 3); + float32x4_t _p8 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[0])), _scale1, 0); + float32x4_t _p9 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[1])), _scale1, 1); + float32x4_t _pa = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[2])), _scale1, 2); + float32x4_t _pb = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[3])), _scale1, 3); + float32x4_t _pc = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[0])), _scale1, 0); + float32x4_t _pd = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[1])), _scale1, 1); + float32x4_t _pe = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[2])), _scale1, 2); + float32x4_t _pf = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[3])), _scale1, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + 
_p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale0); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale0); + _p8 = vmulq_f32(_p8, _scale1); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale1); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale1); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale1); + _pf = vmulq_f32(_pf, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[0]), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[1]), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[2]), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[3]), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_q.val[0]), _scale1, 0); + float32x4_t _p5 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_q.val[1]), _scale1, 1); + float32x4_t _p6 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_q.val[2]), _scale1, 2); + float32x4_t _p7 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_q.val[3]), _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale1); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale1); + _p7 = vmulq_f32(_p7, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p0n = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p1n = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p0n = vmulq_f32(_p0n, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p1n = vmulq_f32(_p1n, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p0n, _p1n); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale1), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); + _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); + _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); + _pc = vmulq_lane_f32(_pc, 
vget_high_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 7)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 
= vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p45 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p67 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + float32x4x2_t _scale01 = vzipq_f32(_scale0, _scale0); + float32x4x2_t _scale23 = vzipq_f32(_scale1, _scale1); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + _p45 = vmulq_f32(_p45, _scale23.val[0]); + _p67 = vmulq_f32(_p67, _scale23.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if 
(elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[0])), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[1])), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[2])), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[3])), _scale, 3); + float32x4_t _p4 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[0])), _scale, 0); + float32x4_t _p5 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[1])), _scale, 1); + float32x4_t _p6 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[2])), _scale, 2); + float32x4_t _p7 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[3])), _scale, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[0]), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[1]), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[2]), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(vcvt_f32_f16((float16x4_t)_p.val[3]), _scale, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = 
vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 3)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + float32x4x2_t _scale01 = vzipq_f32(_scale, _scale); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[A_hstep], _p, 1); + _p = vset_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[A_hstep * 3], _p, 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale0); + pp[2] = float2int8(float16_to_float32(p0[A_hstep]) * scale1); + pp[3] = float2int8(float16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(float16_to_float32(p0[A_hstep]) * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale = scales[ii]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void 
transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (ncnn::cpu_support_arm_asimdhp()) + { + transpose_compute_A_tile_fp16_int8_scales_asimdhp(A, scales, B_scale, out_descales, i, max_ii); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.dims == 3 ? A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_fp16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + +#if __ARM_NEON +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif +#endif + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (elempack == 8) + { + int ii = 0; + for (; ii + 1 < max_ii; ii += 2) + { + const __fp16* p0 = (const __fp16*)A + (i + ii) * 8; + + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 8); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 8 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk < K; kk++) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 8; + } + float absmax0 = (float)vmaxvq_f16(_absmax0); + float absmax1 = (float)vmaxvq_f16(_absmax1); + + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } + for (; ii < max_ii; ii++) + { + const __fp16* p0 = (const __fp16*)A + (i + ii) * 8; + + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + A_hstep * 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 16); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 24); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 32; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + A_hstep * 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float16x8_t _p = vld1q_f16(p0); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); + p0 += A_hstep * 8; + } + float absmax = 
(float)vmaxvq_f16(_absmax0); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (elempack == 4) + { + int ii = 0; + for (; ii + 3 < max_ii; ii += 4) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * 4; + + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 4); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 4 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk < K; kk++) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 4; + } + float16x8_t _aa0123 = vpmaxq_f16(_absmax0, _absmax1); + float32x4_t _absmax = vcvt_f32_f16(vpmax_f16(vget_low_f16(_aa0123), vget_high_f16(_aa0123))); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (int kk = 0; kk < K; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } +#if __aarch64__ + float32x4_t _aa01 = vpmaxq_f32(_absmax0, _absmax1); + float32x4_t _aa23 = vpmaxq_f32(_absmax2, _absmax3); + float32x4_t _absmax = vpmaxq_f32(_aa01, _aa23); +#else + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); + float32x2_t _aa3 = vmax_f32(vget_low_f32(_absmax3), vget_high_f32(_absmax3)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); + float32x4_t _absmax = vcombine_f32(_aa01, _aa23); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax); + float32x4_t _out_descale = vdivq_f32(_absmax, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // 
float32x4_t _recp_absmax = vrecpeq_f32(_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii < max_ii; ii++) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * 4; + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 8); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 12); + float16x4_t _p4 = vld1_f16(p0 + A_hstep * 16); + float16x4_t _p5 = vld1_f16(p0 + A_hstep * 20); + float16x4_t _p6 = vld1_f16(p0 + A_hstep * 24); + float16x4_t _p7 = vld1_f16(p0 + A_hstep * 28); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(vcombine_f16(_p4, _p5))); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(vcombine_f16(_p6, _p7))); + p0 += A_hstep * 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 8); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 12); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + p0 += A_hstep * 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + p0 += A_hstep * 8; + } + float16x4_t _amax01 = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); + for (; kk < K; kk++) + { + float16x4_t _p = vld1_f16(p0); + _amax01 = vmax_f16(_amax01, vabs_f16(_p)); + p0 += A_hstep * 4; + } + float absmax = (float)vmaxv_f16(_amax01); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 12)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + 
_absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep * 4; + } +#if __aarch64__ + float absmax = vmaxvq_f32(_absmax0); +#else + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int ii = 0; +#if __ARM_NEON + for (; ii + 3 < max_ii; ii += 4) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii); + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 2); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 3); + float16x4_t _p4 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p5 = vld1_f16(p0 + A_hstep * 5); + float16x4_t _p6 = vld1_f16(p0 + A_hstep * 6); + float16x4_t _p7 = vld1_f16(p0 + A_hstep * 7); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(vcombine_f16(_p4, _p5))); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(vcombine_f16(_p6, _p7))); + p0 += A_hstep * 8; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 2); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 3); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + p0 += A_hstep * 4; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + p0 += A_hstep * 2; + } + float16x4_t _amax = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); + for (; kk < K; kk++) + { + float16x4_t _p = vld1_f16(p0); + _amax = vmax_f16(_amax, vabs_f16(_p)); + p0 += A_hstep; + } + float32x4_t _absmax0 = vcvt_f32_f16(_amax); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 3)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; 
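+                // partial per-lane absolute maxima; the four accumulators are merged after this loop
+                // and converted to scales (127 / absmax) and out_descales (absmax / (127 * B_scale))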
+ } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } +#endif // __ARM_NEON + for (; ii < max_ii; ii++) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii); + + float absmax = 0.f; + int kk = 0; + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + for (; kk + 7 < K; kk += 8) + { + float16x8_t _p = float16x8_t(); + _p = vsetq_lane_f16(p0[0], _p, 0); + _p = vsetq_lane_f16(p0[A_hstep], _p, 1); + _p = vsetq_lane_f16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_f16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_f16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_f16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_f16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_f16(p0[A_hstep * 7], _p, 7); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); + p0 += A_hstep * 8; + } + float16x4_t _amax0 = vmax_f16(vget_low_f16(_absmax0), vget_high_f16(_absmax0)); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p = float16x4_t(); + _p = vset_lane_f16(p0[0], _p, 0); + _p = vset_lane_f16(p0[A_hstep], _p, 1); + _p = vset_lane_f16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_f16(p0[A_hstep * 3], _p, 3); + _amax0 = vmax_f16(_amax0, vabs_f16(_p)); + p0 += A_hstep * 4; + } + absmax = (float)vmaxv_f16(_amax0); + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf((float)p0[0])); + p0 += A_hstep; + } +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + for (; kk + 7 < K; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + 
float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[A_hstep], _p, 1); + _p = vset_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[A_hstep * 3], _p, 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)_p); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + p0 += A_hstep * 4; + } +#if __aarch64__ + absmax = vmaxvq_f32(_absmax0); +#else + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); +#endif +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(float16_to_float32(p0[0]))); + p0 += A_hstep; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_fp16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_fp16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_fp16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + 32); + uint16x8_t _u = vld1q_u16(p0 + 40); + uint16x8_t _v = vld1q_u16(p0 + 48); + uint16x8_t _w = vld1q_u16(p0 + 56); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, 
_r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + _p8 = vmulq_laneq_f32(_p8, _scale0, 0); + _p9 = vmulq_laneq_f32(_p9, _scale0, 1); + _pa = vmulq_laneq_f32(_pa, _scale0, 2); + _pb = vmulq_laneq_f32(_pb, _scale0, 3); + _pc = vmulq_laneq_f32(_pc, _scale1, 0); + _pd = vmulq_laneq_f32(_pd, _scale1, 1); + _pe = vmulq_laneq_f32(_pe, _scale1, 2); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = 
vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale0), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale0), 1); + _pa = vmulq_lane_f32(_pa, vget_high_f32(_scale0), 0); + _pb = vmulq_lane_f32(_pb, vget_high_f32(_scale0), 1); + _pc = vmulq_lane_f32(_pc, vget_low_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_low_f32(_scale1), 1); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 0); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + 
_p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = 
float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = 
vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + _p4 = vmulq_laneq_f32(_p4, _scale, 0); + _p5 = vmulq_laneq_f32(_p5, _scale, 1); + _p6 = vmulq_laneq_f32(_p6, _scale, 2); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + +#if __aarch64__ + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = 
vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = 
scales[ii]; + const float scale1 = scales[ii + 1]; + +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vzipq_f32(_scale0, _scale1).val[0]; + for (; kk + 7 < max_kk; kk += 8) + { +#if 
__ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p45 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p67 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p46 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p13 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p57 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = 
vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p13 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(float16_to_float32(p0[A_hstep + 0]) * scale0); + pp[2] = float2int8(float16_to_float32(p0[1]) * scale1); + pp[3] = float2int8(float16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale = scales[ii]; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p01 = vld1q_u16(p0); + uint16x8_t _p23 = vld1q_u16(p0 + A_hstep * 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p01 = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + pp += 8; + p0 += A_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = 
vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale); + pp[2] = float2int8(float16_to_float32(p0[2]) * scale); + pp[3] = float2int8(float16_to_float32(p0[3]) * scale); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 15], _q, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void compute_B_fp16_int8_scale(const Mat& B, float& scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (ncnn::cpu_support_arm_asimdhp()) + { + compute_B_fp16_int8_scale_asimdhp(B, scale); + return; + } +#endif + + float absmax = 0.f; +#if __ARM_NEON +#if 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + float16x4_t _amax = vdup_n_f16((__fp16)0.f); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif + for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++) + { + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + const int size = B.w * B.elempack; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* ptr = (const __fp16*)B + i * B_hstep * B.elempack; + + int j = 0; + for (; j + 31 < size; j += 32) + { + float16x8_t _p0 = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr + 8); + float16x8_t _p2 = vld1q_f16(ptr + 16); + float16x8_t _p3 = vld1q_f16(ptr + 24); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + ptr += 32; + } + for (; j + 15 < size; j += 16) + { + float16x8_t _p0 = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + ptr += 16; + } + for (; j + 7 < size; j += 8) + { + float16x8_t _p = vld1q_f16(ptr); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); + ptr += 8; + } + for (; j + 3 < size; j += 4) + { + float16x4_t _p = vld1_f16(ptr); + _amax = vmax_f16(_amax, vabs_f16(_p)); + ptr += 4; + } + for (; j < size; j++) + { + absmax = std::max(absmax, (float)fabsf((float)ptr[0])); + ptr++; + } +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* ptr = (const unsigned short*)B + i * B_hstep * B.elempack; + + int j = 0; +#if __ARM_NEON + for (; j + 15 < size; j += 16) + { + uint16x8_t _p = vld1q_u16(ptr); + uint16x8_t _q = vld1q_u16(ptr + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + ptr += 16; + } + for (; j + 7 < size; j += 8) + { + uint16x8_t _p = vld1q_u16(ptr); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + ptr += 8; + } + for (; j + 3 < size; j += 4) + { + float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(ptr)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + ptr += 4; + } +#endif // __ARM_NEON + for (; j < size; j++) + { + absmax = std::max(absmax, (float)fabsf(float16_to_float32(ptr[0]))); + ptr++; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } +#if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + _absmax0 = vmaxq_f16(_absmax0, _absmax1); + absmax = std::max(absmax, (float)vmaxvq_f16(_absmax0)); + absmax = std::max(absmax, 
(float)vmaxv_f16(_amax)); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + _absmax0 = vmaxq_f32(_absmax0, _absmax1); +#if __aarch64__ + absmax = std::max(absmax, vmaxvq_f32(_absmax0)); +#else + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __ARM_NEON + + scale = absmax == 0.f ? 1.f : 127.f / absmax; +} + +static void pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_fp16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_fp16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_fp16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + 32); + uint16x8_t _u = vld1q_u16(p0 + 40); + uint16x8_t _v = vld1q_u16(p0 + 48); + uint16x8_t _w = vld1q_u16(p0 + 56); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + 
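+                // quantize the scaled fp32 values to int8 and interleave them into the packed layout expected by the int8 gemm kernels (dotprod/i8mm vs plain NEON)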
+#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _p04 = vzip_s8(float2int8(_p0, _p1), float2int8(_p8, _p9)); + int8x8x2_t _p15 = vzip_s8(float2int8(_p2, _p3), float2int8(_pa, _pb)); + int8x8x2_t _p26 = vzip_s8(float2int8(_p4, _p5), float2int8(_pc, _pd)); + int8x8x2_t _p37 = vzip_s8(float2int8(_p6, _p7), float2int8(_pe, _pf)); + + int8x16x4_t _rr; + _rr.val[0] = vcombine_s8(_p04.val[0], _p04.val[1]); + _rr.val[1] = vcombine_s8(_p15.val[0], _p15.val[1]); + _rr.val[2] = vcombine_s8(_p26.val[0], _p26.val[1]); + _rr.val[3] = vcombine_s8(_p37.val[0], _p37.val[1]); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x16x4_t _rr; + _rr.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p8, _p9)); + _rr.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_pa, _pb)); + _rr.val[2] = vcombine_s8(float2int8(_p4, _p5), float2int8(_pc, _pd)); + _rr.val[3] = vcombine_s8(float2int8(_p6, _p7), float2int8(_pe, _pf)); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst4q_s8(pp, _rr); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p8, _p9), float2int8(_pc, _pd)); + _r23.val[1] = vcombine_s8(float2int8(_pa, _pb), float2int8(_pe, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 64; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p01 = vld1q_u16(p0); + uint16x8_t _p23 = vld1q_u16(p0 + 8); + + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = 
float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p01 = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[3])), _scale); + float32x4_t _p8 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[0])), _scale); + float32x4_t _p9 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[1])), _scale); + float32x4_t _pa = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[2])), _scale); + float32x4_t _pb = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_q.val[3])), _scale); + float32x4_t _pc = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[0])), _scale); + float32x4_t _pd = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[1])), _scale); + float32x4_t _pe = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[2])), _scale); + float32x4_t _pf = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_q.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = 
vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[3]), _scale); + float32x4_t _p4 = vmulq_f32(vcvt_f32_f16((float16x4_t)_q.val[0]), _scale); + float32x4_t _p5 = vmulq_f32(vcvt_f32_f16((float16x4_t)_q.val[1]), _scale); + float32x4_t _p6 = vmulq_f32(vcvt_f32_f16((float16x4_t)_q.val[2]), _scale); + float32x4_t _p7 = vmulq_f32(vcvt_f32_f16((float16x4_t)_q.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if 
__ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + 
int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p45 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p67 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0++; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(vcvt_f32_f16((float16x4_t)vget_high_u16(_p.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = 
float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(vcvt_f32_f16((float16x4_t)_p.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); 
+ + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); 
+ _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[B_hstep], _p, 1); + _p = vset_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[B_hstep * 3], _p, 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale); + pp[2] = float2int8(float16_to_float32(p0[B_hstep]) * scale); + pp[3] = float2int8(float16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[B_hstep]) * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + { + const unsigned short* 
p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_fp16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_fp16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_fp16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + 32); + uint16x8_t _u = vld1q_u16(p0 + 40); + uint16x8_t _v = vld1q_u16(p0 + 48); + uint16x8_t _w = vld1q_u16(p0 + 56); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t 
_r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = 
float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + float32x4_t _p8 = vcvt_f32_f16((float16x4_t)vget_low_u16(_t)); + float32x4_t _p9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_t)); + float32x4_t _pa = vcvt_f32_f16((float16x4_t)vget_low_u16(_u)); + float32x4_t _pb = vcvt_f32_f16((float16x4_t)vget_high_u16(_u)); + float32x4_t _pc = vcvt_f32_f16((float16x4_t)vget_low_u16(_v)); + float32x4_t _pd = vcvt_f32_f16((float16x4_t)vget_high_u16(_v)); + float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); + float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t 
_r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_r)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_r)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = 
float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = 
vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale); + pp[2] = float2int8(float16_to_float32(p0[2]) * scale); + pp[3] = float2int8(float16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // 
__ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p45 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p67 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p46 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p13 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p57 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 8; + 
} + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p23 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p13 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[B_hstep + 0]) * scale); + pp[2] = float2int8(float16_to_float32(p0[1]) * scale); + pp[3] = float2int8(float16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON +#if __aarch64__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p01 = vld1q_u16(p0); + uint16x8_t _p23 = vld1q_u16(p0 + B_hstep * 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p01 = vld1q_u16(p0); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + pp += 8; + p0 += B_hstep * 8; + } + } +#endif // __aarch64__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + 
float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 8)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vld1_u16(p0)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp[1] = float2int8(float16_to_float32(p0[1]) * scale); + pp[2] = float2int8(float16_to_float32(p0[2]) * scale); + pp[3] = float2int8(float16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 15], _q, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(float16_to_float32(p0[0]) * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void 
unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + unpack_output_tile_int32_to_fp16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_fp16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = 
vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = 
vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c08 = vld1q_u16(pC); + uint16x8_t _c19 = vld1q_u16(pC + 8); + uint16x8_t _c2a = vld1q_u16(pC + 16); + uint16x8_t _c3b = vld1q_u16(pC + 24); + uint16x8_t _c4c = vld1q_u16(pC + 32); + uint16x8_t _c5d = vld1q_u16(pC + 40); + uint16x8_t _c6e = vld1q_u16(pC + 48); + uint16x8_t _c7f = vld1q_u16(pC + 56); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); + _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); + _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 64; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + 
uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = 
vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _cc = vld1q_u16(pC); + float32x4_t _cc0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc)); + float32x4_t _cc1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + uint16x4_t _hf4 = (uint16x4_t)vcvt_f16_f32(_f4); + uint16x4_t _hf5 = (uint16x4_t)vcvt_f16_f32(_f5); + uint16x4_t _hf6 = 
(uint16x4_t)vcvt_f16_f32(_f6); + uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); + uint16x4_t _hf8 = (uint16x4_t)vcvt_f16_f32(_f8); + uint16x4_t _hf9 = (uint16x4_t)vcvt_f16_f32(_f9); + uint16x4_t _hfa = (uint16x4_t)vcvt_f16_f32(_fa); + uint16x4_t _hfb = (uint16x4_t)vcvt_f16_f32(_fb); + uint16x4_t _hfc = (uint16x4_t)vcvt_f16_f32(_fc); + uint16x4_t _hfd = (uint16x4_t)vcvt_f16_f32(_fd); + uint16x4_t _hfe = (uint16x4_t)vcvt_f16_f32(_fe); + uint16x4_t _hff = (uint16x4_t)vcvt_f16_f32(_ff); + +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf8)); + vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf9)); + vst1q_u16(p0 + 16, vcombine_u16(_hf2, _hfa)); + vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hfb)); + vst1q_u16(p0 + 32, vcombine_u16(_hf4, _hfc)); + vst1q_u16(p0 + 40, vcombine_u16(_hf5, _hfd)); + vst1q_u16(p0 + 48, vcombine_u16(_hf6, _hfe)); + vst1q_u16(p0 + 56, vcombine_u16(_hf7, _hff)); + p0 += 64; + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + vst1q_u16(p0 + 16, vcombine_u16(_hf4, _hf5)); + vst1q_u16(p0 + 24, vcombine_u16(_hf6, _hf7)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf8, _hf9)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_hfa, _hfb)); + vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(_hfc, _hfd)); + vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(_hfe, _hff)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); + transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); + vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_hf1, _hf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_hf2, _hf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_hf3, _hf7)); + transpose4x4_u16(_hf8, _hf9, _hfa, _hfb); + transpose4x4_u16(_hfc, _hfd, _hfe, _hff); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf8, _hfc)); + vst1q_u16(p0 + out_hstep * 5, vcombine_u16(_hf9, _hfd)); + vst1q_u16(p0 + out_hstep * 6, vcombine_u16(_hfa, _hfe)); + vst1q_u16(p0 + out_hstep * 7, vcombine_u16(_hfb, _hff)); + p0 += 8; + } + + pp += 64; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = 
vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c04 = vld1q_u16(pC); + uint16x8_t _c15 = vld1q_u16(pC + 8); + uint16x8_t _c26 = vld1q_u16(pC + 16); + uint16x8_t _c37 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 32; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + 
if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + 
_f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + uint16x4_t _hf4 = (uint16x4_t)vcvt_f16_f32(_f4); + uint16x4_t _hf5 = (uint16x4_t)vcvt_f16_f32(_f5); + uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); + uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); + +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); + vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf5)); + vst1q_u16(p0 + 16, vcombine_u16(_hf2, _hf6)); + vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); + p0 += 32; + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf4, _hf5)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_hf6, _hf7)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); + transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep, _hf1); + vst1_u16(p0 + out_hstep * 2, _hf2); + vst1_u16(p0 + out_hstep * 3, _hf3); + vst1_u16(p0 + out_hstep * 4, _hf4); + vst1_u16(p0 + out_hstep * 5, _hf5); + vst1_u16(p0 + out_hstep * 6, _hf6); + vst1_u16(p0 + out_hstep * 7, _hf7); + p0 += 4; + } + + pp += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _cc0 = vld1q_u16(pC); + uint16x8_t _cc1 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc0)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc1)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc0)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc1)); + pC += 16; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + 
uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 8; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[1], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep + 1], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c01, 7); + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[c_hstep * 4], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep * 5], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 6], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 7], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 2; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(float16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); + vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf3)); + p0 += 16; + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf2, _hf3)); + p0 += 8; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf1, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_hf3, 3); + p0 
+= 2; + } + + pp += 16; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + pC += 8; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + pC += 1; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + p0 += 8; + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0++; + } + + pp += 8; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = 
vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + if (c_elempack == 1) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + 
pC += 8; + } + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _cc1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + uint16x4_t _hf4 = (uint16x4_t)vcvt_f16_f32(_f4); + uint16x4_t _hf5 = (uint16x4_t)vcvt_f16_f32(_f5); + uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); + uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + vst1q_u16(p0 + 16, vcombine_u16(_hf4, _hf5)); + vst1q_u16(p0 + 24, vcombine_u16(_hf6, _hf7)); + p0 += 32; + } + if (out_elempack == 1) + { + transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); + transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); + vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_hf1, _hf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_hf2, _hf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_hf3, _hf7)); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = 
vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep * 1); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + 
_f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + p0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep, _hf1); + vst1_u16(p0 + out_hstep * 2, _hf2); + vst1_u16(p0 + out_hstep * 3, _hf3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + if (c_elempack == 1) + { + _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; + } + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(float16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + p0 += 8; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf1, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf1, 3); + p0 += 2; + } + + pp += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = 
vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } + if (c_elempack == 1) + { + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; + } + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + p0 += 4; + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0++; + } + + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = float16_to_float32(pC[0]) * beta; + c1 = float16_to_float32(pC[1]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = 
vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_u16(p0, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f2), (uint16x4_t)vcvt_f16_f32(_f3))); + + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); + vst1_u16(p0 + out_hstep, (uint16x4_t)vcvt_f16_f32(_f1)); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[c_hstep], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = 
vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += float16_to_float32(pC[0]) * beta; + f1 += float16_to_float32(pC[c_hstep]) * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += float16_to_float32(pC[0]) * beta; + f1 += float16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + + p0[0] = float32_to_float16(f0); + p0[out_hstep] = float32_to_float16(f1); + + pp += 2; + p0++; + } + } + for (; ii < max_ii; ii += 1) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 
= vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_u16(p0, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f1))); + vst1q_u16(p0 + 8, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f2), (uint16x4_t)vcvt_f16_f32(_f3))); + + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_u16(p0, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f1))); + + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); + + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _cc = float32x2_t(); + _cc = vset_lane_f32(float16_to_float32(pC[0]), _cc, 0); + _cc = vset_lane_f32(float16_to_float32(pC[1]), _cc, 1); + _f0 = vmla_n_f32(_f0, _cc, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + p0[0] = float32_to_float16(vget_lane_f32(_f0, 0)); + p0[1] = float32_to_float16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += float16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = float32_to_float16(f0); + + pp += 1; + p0++; + } + } +} + +static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + 
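// Runtime dispatch: this translation unit was built without
// __ARM_FEATURE_DOTPROD, but a dotprod-enabled copy of the same routine is
// compiled separately (the *_asimddp variant), so when the CPU reports
// asimddp support at runtime the tile is handed off to that copy instead.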
transpose_unpack_output_tile_int32_to_fp16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("transpose_unpack_output_tile_int32_to_fp16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + 
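// The vrev64q_s32 + vextq_s32(v, v, 2) pair applied to _sum8.._sumf reverses
// the four 32-bit lanes of each register (x0 x1 x2 x3 -> x3 x2 x1 x0); the
// vzipq/vcombine steps that follow then regroup the lanes so the accumulators
// match the row-major "to" layout listed above instead of the rotated "from"
// layout produced without the dot-product extension.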
_sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || 
broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c08 = vld1q_u16(pC); + uint16x8_t _c19 = vld1q_u16(pC + 8); + uint16x8_t _c2a = vld1q_u16(pC + 16); + uint16x8_t _c3b = vld1q_u16(pC + 24); + uint16x8_t _c4c = vld1q_u16(pC + 32); + uint16x8_t _c5d = vld1q_u16(pC + 40); + uint16x8_t _c6e = vld1q_u16(pC + 48); + uint16x8_t _c7f = vld1q_u16(pC + 56); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); + float32x4_t _c8 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); + float32x4_t _c9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); + float32x4_t _ca = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); + float32x4_t _cb = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); + float32x4_t _cc = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); + float32x4_t _cd = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); + float32x4_t _ce = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); + float32x4_t _cf = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); + + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + _f8 = vmlaq_f32(_f8, _c8, _beta); + _f9 = vmlaq_f32(_f9, _c9, _beta); + _fa = vmlaq_f32(_fa, _ca, _beta); + _fb = vmlaq_f32(_fb, _cb, _beta); + _fc = vmlaq_f32(_fc, _cc, _beta); + _fd = vmlaq_f32(_fd, _cd, _beta); + _fe = vmlaq_f32(_fe, _ce, _beta); + _ff = vmlaq_f32(_ff, _cf, _beta); + } + pC += 64; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = 
vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + 
c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _cc1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x8_t _hf0 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f8)); + uint16x8_t _hf1 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f1), (uint16x4_t)vcvt_f16_f32(_f9)); + uint16x8_t _hf2 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f2), (uint16x4_t)vcvt_f16_f32(_fa)); + uint16x8_t _hf3 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f3), (uint16x4_t)vcvt_f16_f32(_fb)); + uint16x8_t _hf4 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f4), (uint16x4_t)vcvt_f16_f32(_fc)); + uint16x8_t _hf5 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f5), (uint16x4_t)vcvt_f16_f32(_fd)); + uint16x8_t _hf6 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f6), 
(uint16x4_t)vcvt_f16_f32(_fe)); + uint16x8_t _hf7 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f7), (uint16x4_t)vcvt_f16_f32(_ff)); + +#if __aarch64__ + if (out_elempack == 8) + { + transpose8x8_u16(_hf0, _hf1, _hf2, _hf3, _hf4, _hf5, _hf6, _hf7); + vst1q_u16(p0, _hf0); + vst1q_u16(p0 + 8, _hf1); + vst1q_u16(p0 + 16, _hf2); + vst1q_u16(p0 + 24, _hf3); + vst1q_u16(p0 + 32, _hf4); + vst1q_u16(p0 + 40, _hf5); + vst1q_u16(p0 + 48, _hf6); + vst1q_u16(p0 + 56, _hf7); + } +#endif // __aarch64__ + if (out_elempack == 4) + { + uint16x8x4_t _hfa; + uint16x8x4_t _hfb; + _hfa.val[0] = _hf0; + _hfa.val[1] = _hf1; + _hfa.val[2] = _hf2; + _hfa.val[3] = _hf3; + _hfb.val[0] = _hf4; + _hfb.val[1] = _hf5; + _hfb.val[2] = _hf6; + _hfb.val[3] = _hf7; + vst4q_u16(p0, _hfa); + vst4q_u16(p0 + out_hstep * 4, _hfb); + } + if (out_elempack == 1) + { + vst1q_u16(p0, _hf0); + vst1q_u16(p0 + out_hstep, _hf1); + vst1q_u16(p0 + out_hstep * 2, _hf2); + vst1q_u16(p0 + out_hstep * 3, _hf3); + vst1q_u16(p0 + out_hstep * 4, _hf4); + vst1q_u16(p0 + out_hstep * 5, _hf5); + vst1q_u16(p0 + out_hstep * 6, _hf6); + vst1q_u16(p0 + out_hstep * 7, _hf7); + } + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = 
vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c04 = vld1q_u16(pC); + uint16x8_t _c15 = vld1q_u16(pC + 8); + uint16x8_t _c26 = vld1q_u16(pC + 16); + uint16x8_t _c37 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + pC += 32; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; + } 
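// The c_elempack == 1 branch below handles row-major C: it gathers four fp16
// bias values from each of the first four C rows, transposes the 4x4 block so
// every vector holds a single output column, and then repeats the same gather
// for the next four rows before the beta multiply-add.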
+ if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x8_t _hf0 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f4)); + uint16x8_t _hf1 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f1), (uint16x4_t)vcvt_f16_f32(_f5)); + uint16x8_t _hf2 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f2), (uint16x4_t)vcvt_f16_f32(_f6)); + uint16x8_t _hf3 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f3), (uint16x4_t)vcvt_f16_f32(_f7)); + + if (out_elempack == 4) + { + uint16x8x4_t _hf; + _hf.val[0] = _hf0; + _hf.val[1] = _hf1; + _hf.val[2] = _hf2; + _hf.val[3] = _hf3; + vst4q_u16(p0, _hf); + } + if (out_elempack == 1) + { + vst1q_u16(p0, _hf0); + vst1q_u16(p0 + out_hstep, _hf1); + vst1q_u16(p0 + out_hstep * 2, _hf2); + vst1q_u16(p0 + out_hstep * 3, _hf3); + } + + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if 
__ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c02 = vld1q_u16(pC); + uint16x8_t _c13 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c02)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c13)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c02)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c13)); + pC += 16; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 8; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[1], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 2; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta 
= vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(float16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + vst1q_u16(p0, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f2))); + vst1q_u16(p0 + out_hstep, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f1), (uint16x4_t)vcvt_f16_f32(_f3))); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { +#if __aarch64__ + if (c_elempack == 8) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + pC += 8; + } +#endif // __aarch64__ + if (c_elempack == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + pC += 1; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_u16(p0, vcombine_u16((uint16x4_t)vcvt_f16_f32(_f0), (uint16x4_t)vcvt_f16_f32(_f1))); + pp += 8; + p0 += out_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * 
c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + if (c_elempack == 1) 
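// row-major C: load eight fp16 values from each of the four C rows and
// transpose8x4 so each 4-lane half holds one output column's bias across the
// rows, matching the per-column accumulators _f0.._f7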
+ { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + pC += 8; + } + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _cc1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + uint16x4_t _hf4 = (uint16x4_t)vcvt_f16_f32(_f4); + uint16x4_t _hf5 = (uint16x4_t)vcvt_f16_f32(_f5); + uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); + uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); + +#if __aarch64__ + if (out_elempack == 8) + { + transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); + transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); + vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); + vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf5)); + vst1q_u16(p0 + 16, vcombine_u16(_hf2, _hf6)); + vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); + } +#endif // __aarch64__ + if (out_elempack == 4) + { + uint16x4x4_t _hfa; + uint16x4x4_t _hfb; + _hfa.val[0] = _hf0; + _hfa.val[1] = _hf1; + _hfa.val[2] = _hf2; + _hfa.val[3] = _hf3; + _hfb.val[0] = _hf4; + _hfb.val[1] 
= _hf5; + _hfb.val[2] = _hf6; + _hfb.val[3] = _hf7; + vst4_u16(p0, _hfa); + vst4_u16(p0 + out_hstep * 4, _hfb); + } + if (out_elempack == 1) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep, _hf1); + vst1_u16(p0 + out_hstep * 2, _hf2); + vst1_u16(p0 + out_hstep * 3, _hf3); + vst1_u16(p0 + out_hstep * 4, _hf4); + vst1_u16(p0 + out_hstep * 5, _hf5); + vst1_u16(p0 + out_hstep * 6, _hf6); + vst1_u16(p0 + out_hstep * 7, _hf7); + } + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 16; + } + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); + pC += 4; + } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _c = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = 
vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + if (out_elempack == 4) + { + uint16x4x4_t _hf; + _hf.val[0] = _hf0; + _hf.val[1] = _hf1; + _hf.val[2] = _hf2; + _hf.val[3] = _hf3; + vst4_u16(p0, _hf); + } + if (out_elempack == 1) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep, _hf1); + vst1_u16(p0 + out_hstep * 2, _hf2); + vst1_u16(p0 + out_hstep * 3, _hf3); + } + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + if (c_elempack == 1) + { + _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; + } + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(float16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); + vst1_u16(p0 + 
out_hstep, (uint16x4_t)vcvt_f16_f32(_f1)); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } + if (c_elempack == 1) + { + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; + } + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(float16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); + pp += 4; + p0 += out_hstep; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale01 = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = float16_to_float32(pC[0]) * beta; + c1 = float16_to_float32(pC[1]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 
= vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf1, _hf3)); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf2, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_hf3, 3); + } + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + if (out_elempack == 1) + { + 
p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf1, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf1, 3); + } + + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[0], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); + + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += float16_to_float32(pC[0]) * beta; + f1 += float16_to_float32(pC[c_hstep]) * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + c0 = float16_to_float32(pC[0]) * beta; + f0 += c0; + f1 += c0; + pC += 1; + } + } + + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + + p0[0] = float32_to_float16(f0); + p0[1] = float32_to_float16(f1); + pp += 2; + p0 += out_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = float16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = 
vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + pC += 16; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } + else + { +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + out_hstep * 8, vcombine_u16(_hf2, _hf3)); + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + vst1_u16(p0 + out_hstep * 8, _hf2); + vst1_u16(p0 + out_hstep * 12, _hf3); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + } + } + + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + float32x4_t _c1 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + else + { +#if __aarch64__ + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } +#endif // __aarch64__ + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + } + } + + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } + + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + + if (out_hstep == 1) + { + vst1_u16(p0, _hf0); + } + else + { + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + } + } + + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _c = float32x2_t(); + _c = vset_lane_f32(float16_to_float32(pC[0]), _c, 0); + _c = vset_lane_f32(float16_to_float32(pC[1]), _c, 1); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; + } + } + + _f0 = vmul_n_f32(_f0, alpha); + + p0[0] = float32_to_float16(vget_lane_f32(_f0, 0)); + p0[out_hstep] = float32_to_float16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += float16_to_float32(pC[0]) * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = float32_to_float16(f0); + + pp += 1; + p0 += out_hstep; + } + } +} diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp index de6b2adeb956..0ebe5974d0b7 100644 --- a/src/layer/gemm.cpp +++ b/src/layer/gemm.cpp @@ -39,10 +39,19 @@ int Gemm::load_param(const ParamDict& pd) output_elempack = pd.get(12, 0); output_elemtype = 
pd.get(13, 0);
     output_transpose = pd.get(14, 0);
+    int8_scale_term = pd.get(18, 0);
     constant_TILE_M = pd.get(20, 0);
     constant_TILE_N = pd.get(21, 0);
     constant_TILE_K = pd.get(22, 0);
 
+    if (int8_scale_term)
+    {
+#if !NCNN_INT8
+        NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference");
+        return -1;
+#endif
+    }
+
     if (constantA == 1 && (constantM == 0 || constantK == 0))
     {
         NCNN_LOGE("constantM and constantK must be non-zero when constantA enabled");
@@ -111,9 +120,175 @@ int Gemm::load_model(const ModelBin& mb)
             return -100;
     }
 
+#if NCNN_INT8
+    if (int8_scale_term)
+    {
+        if (constantA == 1)
+        {
+            A_data_int8_scales = mb.load(constantM, 1);
+        }
+
+        if (constantB == 1)
+        {
+            B_data_int8_scale = mb.load(1, 1)[0];
+        }
+    }
+#endif // NCNN_INT8
+
     return 0;
 }
 
+static void gemm_transB(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, float alpha, float beta, int broadcast_type_C, int output_transpose, const Option& opt)
+{
+    const int M = A.dims == 3 ? A.c : A.h;
+    const int N = BT.dims == 3 ? BT.c : BT.h;
+    const int K = A.w; // assert A.w == BT.w
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int i = 0; i < M; i++)
+    {
+        const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w;
+
+        const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w;
+        const int BT_hstep = BT.dims == 3 ? (int)BT.cstep : BT.w;
+
+        const float* ptrA = (const float*)A + i * A_hstep;
+        const float* ptrC = C;
+
+        for (int j = 0; j < N; j++)
+        {
+            const float* ptrBT = (const float*)BT + j * BT_hstep;
+
+            float sum = 0.f;
+            if (ptrC)
+            {
+                if (broadcast_type_C == 0)
+                {
+                    sum = ptrC[0];
+                }
+                if (broadcast_type_C == 1)
+                {
+                    sum = ptrC[i];
+                }
+                if (broadcast_type_C == 2)
+                {
+                    sum = ptrC[i];
+                }
+                if (broadcast_type_C == 3)
+                {
+                    sum = ptrC[i * N + j];
+                }
+                if (broadcast_type_C == 4)
+                {
+                    sum = ptrC[j];
+                }
+
+                sum *= beta;
+            }
+
+            for (int k = 0; k < K; k++)
+            {
+                sum += ptrA[k] * ptrBT[k];
+            }
+
+            sum *= alpha;
+
+            if (output_transpose)
+            {
+                top_blob[j * out_hstep + i] = sum;
+            }
+            else
+            {
+                top_blob[i * out_hstep + j] = sum;
+            }
+        }
+    }
+}
+
+#if NCNN_INT8
+static inline signed char float2int8(float v)
+{
+    int int32 = static_cast<int>(round(v));
+    if (int32 > 127) return 127;
+    if (int32 < -127) return -127;
+    return (signed char)int32;
+}
+
+static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A_int8_scales, float BT_int8_scale, const Mat& C, Mat& top_blob, float alpha, float beta, int broadcast_type_C, int output_transpose, const Option& opt)
+{
+    const int M = A_int8.h;
+    const int N = BT_int8.h;
+    const int K = A_int8.w; // assert A_int8.w == BT_int8.w
+
+    // NCNN_LOGE("naive ds %f %f", A_int8_scales[0], BT_int8_scale);
+
+    // #pragma omp parallel for num_threads(opt.num_threads)
+    for (int i = 0; i < M; i++)
+    {
+        const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; + + const signed char* ptrA = A_int8.row(i); + const float* ptrC = C; + + const float descale = 1.f / (A_int8_scales[i] * BT_int8_scale); + + // NCNN_LOGE("descale %f", descale); + + for (int j = 0; j < N; j++) + { + const signed char* ptrBT = BT_int8.row(j); + + int sum = 0; + for (int k = 0; k < K; k++) + { + // NCNN_LOGE("ptrA[%d] %d", k, ptrA[k]); + sum += ptrA[k] * ptrBT[k]; + } + + float sum_fp32 = sum * descale; + + if (ptrC) + { + float c = 0.f; + if (broadcast_type_C == 0) + { + c = ptrC[0]; + } + if (broadcast_type_C == 1) + { + c = ptrC[i]; + } + if (broadcast_type_C == 2) + { + c = ptrC[i]; + } + if (broadcast_type_C == 3) + { + c = ptrC[i * N + j]; + } + if (broadcast_type_C == 4) + { + c = ptrC[j]; + } + + sum_fp32 += c * beta; + } + + sum_fp32 *= alpha; + + if (output_transpose) + { + top_blob[j * out_hstep + i] = sum_fp32; + } + else + { + top_blob[i * out_hstep + j] = sum_fp32; + } + } + } +} +#endif // NCNN_INT8 + int Gemm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { std::vector bottom_blobs(1, bottom_blob); @@ -125,6 +300,13 @@ int Gemm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif // NCNN_INT8 + const Mat& A0 = constantA ? A_data : bottom_blobs[0]; const Mat& B0 = constantB ? B_data : constantA ? bottom_blobs[0] : bottom_blobs[1]; @@ -152,18 +334,18 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl } } - Mat B; + Mat BT; if (transB == 0) { // transpose B to col-major - B.create((B0.dims == 3 ? B0.c : B0.h), B0.w, elemsize, opt.workspace_allocator); + BT.create((B0.dims == 3 ? B0.c : B0.h), B0.w, elemsize, opt.workspace_allocator); const int B0_hstep = B0.dims == 3 ? (int)B0.cstep : B0.w; - for (int i = 0; i < B.h; i++) + for (int i = 0; i < BT.h; i++) { - float* ptr = B.row(i); - for (int j = 0; j < B.w; j++) + float* ptr = BT.row(i); + for (int j = 0; j < BT.w; j++) { ptr[j] = B0[j * B0_hstep + i]; } @@ -171,43 +353,36 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl } else { - B = B0; + BT = B0; } const int M = A.dims == 3 ? A.c : A.h; - const int K = A.w; // assert A.w == B.w - const int N = B.dims == 3 ? B.c : B.h; + const int N = BT.dims == 3 ? BT.c : BT.h; - const float* ptrC = 0; + Mat C; int broadcast_type_C = 0; if (constantC) { - ptrC = C_data; + C = C_data; broadcast_type_C = constant_broadcast_type_C; } else { - if (constantA && constantB) + if (constantA && constantB && bottom_blobs.size() == 1) { - ptrC = bottom_blobs.size() == 1 ? bottom_blobs[0] : 0; + C = bottom_blobs[0]; } - else if (constantA) + else if ((constantA || constantB) && bottom_blobs.size() == 2) { - ptrC = bottom_blobs.size() == 2 ? bottom_blobs[1] : 0; + C = bottom_blobs[1]; } - else if (constantB) - { - ptrC = bottom_blobs.size() == 2 ? bottom_blobs[1] : 0; - } - else + else if (bottom_blobs.size() == 3) { - ptrC = bottom_blobs.size() == 3 ? 
bottom_blobs[2] : 0; + C = bottom_blobs[2]; } - if (ptrC) + if (!C.empty()) { - const Mat& C = bottom_blobs[bottom_blobs.size() - 1]; - if (C.dims == 1 && C.w == 1) { // scalar @@ -260,66 +435,226 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl if (top_blob.empty()) return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < M; i++) - { - const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + gemm_transB(A, BT, C, top_blob, alpha, beta, broadcast_type_C, output_transpose, opt); - const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; - const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + return 0; +} - const float* ptrA = (const float*)A + i * A_hstep; +#if NCNN_INT8 +int Gemm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& A0 = constantA ? A_data : bottom_blobs[0]; + const Mat& B0 = constantB ? B_data : constantA ? bottom_blobs[0] : bottom_blobs[1]; - for (int j = 0; j < N; j++) + Mat A; + if (transA == 0) + { + A = A0; + } + else + { + // transpose A to row-major + if (A0.elemsize == 1) { - const float* ptrB = (const float*)B + j * B_hstep; + A.create(A0.h, A0.w, (size_t)1u, 1, opt.workspace_allocator); - float sum = 0.f; - if (ptrC) + for (int i = 0; i < A.h; i++) { - if (broadcast_type_C == 0) - { - sum = ptrC[0]; - } - if (broadcast_type_C == 1) - { - sum = ptrC[i]; - } - if (broadcast_type_C == 2) + signed char* ptr = A.row(i); + for (int j = 0; j < A.w; j++) { - sum = ptrC[i]; + ptr[j] = A0.row(j)[i]; } - if (broadcast_type_C == 3) - { - sum = ptrC[i * N + j]; - } - if (broadcast_type_C == 4) + } + } + else + { + A.create(A0.dims == 3 ? A0.c : A0.h, A0.w, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < A.h; i++) + { + float* ptr = A.row(i); + for (int j = 0; j < A.w; j++) { - sum = ptrC[j]; + ptr[j] = A0.dims == 3 ? A0.channel(j)[i] : A0.row(j)[i]; } + } + } + } - sum *= beta; + // dynamic quantize A + Mat A_int8 = A; + Mat A_int8_scales = A_data_int8_scales; + if (A_int8.elemsize != 1) + { + A_int8.create(A.w, A.dims == 3 ? A.c : A.h, (size_t)1u, 1, opt.workspace_allocator); + A_int8_scales.create(A_int8.h, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < A_int8.h; i++) + { + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const float* ptr = (const float*)A + i * A_hstep; + + float absmax = 0.f; + for (int k = 0; k < A_int8.w; k++) + { + absmax = std::max(absmax, (float)fabs(ptr[k])); } - for (int k = 0; k < K; k++) + // NCNN_LOGE("A[%d] absmax %f", i, absmax); + + float A_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax; + A_int8_scales[i] = A_int8_scale; + + signed char* ptrAi = A_int8.row(i); + + for (int k = 0; k < A_int8.w; k++) { - sum += ptrA[k] * ptrB[k]; + ptrAi[k] = float2int8(ptr[k] * A_int8_scale); } + } + } - sum *= alpha; + // dynamic quantize B + Mat B0_int8 = B0; + float B_int8_scale = B_data_int8_scale; + if (B0_int8.elemsize != 1) + { + B0_int8.create(B0.w, B0.dims == 3 ? B0.c : B0.h, (size_t)1u, 1, opt.workspace_allocator); - if (output_transpose) + float absmax = 0.f; + for (int i = 0; i < B0_int8.h; i++) + { + const int B_hstep = B0.dims == 3 ? (int)B0.cstep : B0.w; + const float* ptr = (const float*)B0 + i * B_hstep; + + for (int k = 0; k < B0_int8.w; k++) { - top_blob[j * out_hstep + i] = sum; + absmax = std::max(absmax, (float)fabs(ptr[k])); } - else + } + + // NCNN_LOGE("B0 absmax %f", absmax); + + B_int8_scale = absmax == 0.f ? 
1.f : 127.f / absmax; + + for (int i = 0; i < B0_int8.h; i++) + { + const int B_hstep = B0.dims == 3 ? (int)B0.cstep : B0.w; + const float* ptr = (const float*)B0 + i * B_hstep; + + signed char* ptrBi = B0_int8.row(i); + + for (int k = 0; k < B0_int8.w; k++) { - top_blob[i * out_hstep + j] = sum; + ptrBi[k] = float2int8(ptr[k] * B_int8_scale); } } } + Mat BT_int8; + if (transB == 0) + { + // transpose B to col-major + BT_int8.create(B0_int8.h, B0_int8.w, (size_t)1u, 1, opt.workspace_allocator); + + for (int i = 0; i < BT_int8.h; i++) + { + signed char* ptr = BT_int8.row(i); + for (int j = 0; j < BT_int8.w; j++) + { + ptr[j] = B0_int8.row(j)[i]; + } + } + } + else + { + BT_int8 = B0_int8; + } + + const int M = A_int8.h; + const int N = BT_int8.h; + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = C_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB && bottom_blobs.size() == 1) + { + C = bottom_blobs[0]; + } + else if ((constantA || constantB) && bottom_blobs.size() == 2) + { + C = bottom_blobs[1]; + } + else if (bottom_blobs.size() == 3) + { + C = bottom_blobs[2]; + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } + } + } + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N, 4u, opt.blob_allocator); + else + top_blob.create(M, N, 4u, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M, 4u, opt.blob_allocator); + else + top_blob.create(N, M, 4u, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + gemm_transB_int8(A_int8, BT_int8, A_int8_scales, B_int8_scale, C, top_blob, alpha, beta, broadcast_type_C, output_transpose, opt); + return 0; } +#endif // NCNN_INT8 } // namespace ncnn diff --git a/src/layer/gemm.h b/src/layer/gemm.h index e006114a1498..8408772c12c0 100644 --- a/src/layer/gemm.h +++ b/src/layer/gemm.h @@ -32,6 +32,11 @@ class Gemm : public Layer virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif + public: float alpha; float beta; @@ -50,6 +55,8 @@ class Gemm : public Layer int output_elemtype; // 0=auto 1=fp32 int output_transpose; + int int8_scale_term; + int constant_TILE_M; int constant_TILE_N; int constant_TILE_K; @@ -58,6 +65,11 @@ class Gemm : public Layer Mat A_data; Mat B_data; Mat C_data; + +#if NCNN_INT8 + Mat A_data_int8_scales; + float B_data_int8_scale; +#endif }; } // namespace ncnn diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp index fa25a058cb1c..8dee572548ed 100644 --- a/src/layer/riscv/gemm_riscv.cpp +++ b/src/layer/riscv/gemm_riscv.cpp @@ -3947,6 +3947,14 @@ int Gemm_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt int Gemm_riscv::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + support_packing = false; + return 0; + } +#endif + if (constantA) 
{ const int M = constantM; @@ -4070,6 +4078,13 @@ int Gemm_riscv::create_pipeline(const Option& opt) int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + int M; int N; if (constantA && constantB) diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 0d403a5288b9..eed5dd357fd6 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -26,6 +26,19 @@ Gemm_vulkan::Gemm_vulkan() pipeline_gemm = 0; } +int Gemm_vulkan::load_param(const ParamDict& pd) +{ + int ret = Gemm::load_param(pd); + + if (int8_scale_term) + { + support_vulkan = false; + support_image_storage = false; + } + + return ret; +} + int Gemm_vulkan::create_pipeline(const Option& opt) { // const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0]; @@ -169,56 +182,78 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vectorconvert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) + VkMat C0; + if (constantA && constantB) { - // scalar - broadcast_type_C = 0; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkMat(); } - if (C.dims == 1 && C.w == M) + else if (constantA) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 1 && C.w == N) + else if (constantB) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) - { - // Mx1 - broadcast_type_C = 2; - } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } } @@ -301,56 +336,78 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vecto { const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0]; const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1]; - const VkImageMat& C0 = constantC ? 
C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1]; VkImageMat A; VkImageMat B; - VkImageMat C; vkdev->convert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkImageMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu_image, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) - { - // scalar - broadcast_type_C = 0; - } - if (C.dims == 1 && C.w == M) + VkImageMat C0; + if (constantA && constantB) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkImageMat(); } - if (C.dims == 1 && C.w == N) + else if (constantA) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) + else if (constantB) { - // Mx1 - broadcast_type_C = 2; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } } diff --git a/src/layer/vulkan/gemm_vulkan.h b/src/layer/vulkan/gemm_vulkan.h index d9fa92018e42..b1b37927b400 100644 --- a/src/layer/vulkan/gemm_vulkan.h +++ b/src/layer/vulkan/gemm_vulkan.h @@ -24,6 +24,8 @@ class Gemm_vulkan : public Gemm public: Gemm_vulkan(); + virtual int load_param(const ParamDict& pd); + virtual int create_pipeline(const Option& opt); virtual int destroy_pipeline(const Option& opt); diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index d9b11f7ea8fa..268f85f332d8 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -7222,6 +7222,14 @@ static int gemm_AT_BT_x86(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_b int Gemm_x86::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + support_packing = false; + return 0; + } +#endif + if (constantA) { const int M = constantM; @@ -7355,6 +7363,13 @@ int Gemm_x86::create_pipeline(const Option& opt) int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + int M; int N; if (constantA && constantB) diff --git a/tests/test_gemm.cpp b/tests/test_gemm.cpp index c2900e9ac611..95f5436f6843 100644 --- a/tests/test_gemm.cpp +++ b/tests/test_gemm.cpp @@ -300,7 +300,7 @@ int main() || test_gemm_1(M, 
N, K); if (ret != 0) - return 0; + return ret; } return 0; diff --git a/tests/test_gemm_1.cpp b/tests/test_gemm_1.cpp index 59a0c8256278..7179bf8d26f9 100644 --- a/tests/test_gemm_1.cpp +++ b/tests/test_gemm_1.cpp @@ -120,13 +120,13 @@ int main() int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K); if (ret != 0) - return 0; + return ret; } // test no tiling int ret = test_gemm_0(M, N, K, 100, 100, 100); if (ret != 0) - return 0; + return ret; } return 0; diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp new file mode 100644 index 000000000000..d7c6c531a05f --- /dev/null +++ b/tests/test_gemm_3.cpp @@ -0,0 +1,336 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +#if NCNN_INT8 +static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(13, output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); + if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? 
ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + } + + return ret; +} + +static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int constantC) +{ + int broadcast_type_C = 0; + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } + + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, beta); + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, constantC); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, broadcast_type_C); + // pd.set(12, 1); // output_elempack + pd.set(13, output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M)); + if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K)); + if (constantC) weights.push_back(C); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M)); + if (!constantB) a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K)); + if (!constantC) a.push_back(C); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC); + } + + return ret; +} + +static int test_gemm_int8_fp16s(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(13, output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); + if (constantB) weights.push_back(transB ? (output_N1M ? 
RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = true; + opt.use_fp16_arithmetic = false; + opt.use_bf16_storage = false; + + float epsilon = 0.001; + + int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, epsilon); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8_fp16s failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + return ret; + } + + return 0; +} + +static int test_gemm_0(int M, int N, int K) +{ + return 0 + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 0) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 0, 0, 0, 0) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 0, 0, 0, 0, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 0, 0, 0, 0, 1) + + || test_gemm_int8(M, N, K, 0.2f, 0, 1, 0, 0, 1, 0, 1) + || test_gemm_int8(M, N, K, 0.3f, 1, 1, 0, 0, 1, 0, 1) + || test_gemm_int8(M, N, K, 0.4f, 0, 0, 0, 0, 0, 1, 0) + || test_gemm_int8(M, N, K, 0.5f, 0, 1, 0, 0, 0, 1, 0) + + || test_gemm_int8(M, N, K, 1.2f, 0, 1, 0, 0, 1, 1, 0) + || test_gemm_int8(M, N, K, 1.3f, 1, 1, 0, 0, 1, 1, 1) + || test_gemm_int8(M, N, K, 1.4f, 0, 0, 0, 0, 1, 1, 0) + || test_gemm_int8(M, N, K, 1.5f, 1, 0, 0, 0, 1, 1, 1) + + || test_gemm_int8(M, N, K, -1.2f, 0, 1, 0, 1, 0, 0, 0) + || test_gemm_int8(M, N, K, -1.3f, 1, 1, 0, 1, 0, 0, 0) + || test_gemm_int8(M, N, K, -1.4f, 0, 0, 0, 1, 0, 0, 1) + || test_gemm_int8(M, N, K, -1.5f, 1, 0, 0, 1, 0, 0, 1) + + || test_gemm_int8(M, N, K, -2.0f, 0, 1, 0, 1, 1, 0, 1) + || test_gemm_int8(M, N, K, -3.0f, 1, 1, 0, 1, 1, 0, 1) + || test_gemm_int8(M, N, K, -4.0f, 0, 0, 0, 1, 0, 1, 0) + || test_gemm_int8(M, N, K, -5.0f, 0, 1, 0, 1, 0, 1, 0) + + || test_gemm_int8(M, N, K, -2.1f, 0, 1, 0, 1, 1, 1, 0) + || test_gemm_int8(M, N, K, -3.1f, 1, 1, 0, 1, 1, 1, 1) + || test_gemm_int8(M, N, K, -4.1f, 0, 0, 0, 1, 1, 1, 0) + || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1) + + || test_gemm_int8_fp16s(M, N, K, 1.f, 0, 1, 0, 0, 0, 0, 0) + || test_gemm_int8_fp16s(M, N, K, 1.f, 1, 0, 0, 1, 0, 0, 0); +} + +static int test_gemm_1(int M, int N, int K) +{ + return 0 + || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0) + || 
test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0) + + || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1); +} +#endif // NCNN_INT8 + +int main() +{ + SRAND(7767517); + +#if NCNN_INT8 + int mnk[][3] = { + {1, 1, 1}, + {1, 1, 23}, + {1, 1, 47}, + {1, 23, 1}, + {1, 23, 23}, + {1, 31, 1}, + {1, 35, 1}, + {1, 35, 47}, + {1, 47, 1}, + {2, 2, 2}, + {3, 3, 3}, + {4, 4, 4}, + {5, 5, 5}, + {6, 6, 6}, + {7, 7, 7}, + {7, 31, 3}, + {8, 8, 8}, + {12, 12, 23}, + {12, 23, 12}, + {12, 31, 12}, + {15, 15, 15}, + {16, 16, 16}, + {19, 44, 7}, + {20, 28, 7}, + {23, 31, 1}, + {23, 31, 23}, + {24, 24, 47}, + {24, 35, 24}, + {24, 47, 24}, + {31, 31, 31}, + {32, 32, 9}, + {35, 47, 48}, + {35, 48, 47}, + {40, 40, 40}, + {47, 48, 47} + }; + + int mnk_count = sizeof(mnk) / sizeof(int) / 3; + + for (int i = 0; i < mnk_count; i++) + { + int M = mnk[i][0]; + int N = mnk[i][1]; + int K = mnk[i][2]; + + int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); + if (ret != 0) + return ret; + + if (M != N) + { + int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K); + if (ret != 0) + return ret; + } + } +#else + // test nothing for non-int8 build +#endif + + return 0; +} diff --git a/tests/test_gemm_4.cpp b/tests/test_gemm_4.cpp new file mode 100644 index 000000000000..3b25cf9e9f97 --- /dev/null +++ b/tests/test_gemm_4.cpp @@ -0,0 +1,140 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +#if NCNN_INT8 +static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + pd.set(20, TILE_M); + pd.set(21, TILE_N); + pd.set(22, TILE_K); + + std::vector weights(0); + + std::vector a(2); + a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M); + a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K); + + Randomize(a[0], -10.f, 10.f); + Randomize(a[1], -10.f, 10.f); + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose); + } + + return ret; +} + +static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K) +{ + return 0 + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1); +} +#endif // NCNN_INT8 + +int main() +{ + SRAND(7767517); + +#if NCNN_INT8 + int mnk[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {3, 3, 3}, + {4, 4, 4}, + {5, 5, 5}, + {6, 6, 6}, + {7, 7, 7}, + {8, 8, 8}, + {15, 15, 15}, + {16, 16, 16}, + {24, 24, 24}, + {31, 31, 31}, + {31, 32, 31}, + {32, 31, 32}, + {32, 32, 32}, + {20, 32, 20}, + {40, 40, 40}, + {47, 47, 47}, + {48, 48, 48}, + {52, 52, 52}, + {63, 64, 63}, + {64, 63, 64}, + {64, 64, 64} + }; + + int tile_mnk[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {4, 4, 4}, + {8, 8, 8}, + {12, 12, 12}, + {16, 16, 16}, + {20, 20, 20}, + {24, 24, 24}, + {28, 28, 28} + }; + + int mnk_count = sizeof(mnk) / sizeof(int) / 3; + int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3; + + for (int i = 0; i < mnk_count; i++) + { + int M = mnk[i][0]; + int N = mnk[i][1]; + int K = mnk[i][2]; + + for (int j = 0; j < tile_mnk_count; j++) + { + int TILE_M = tile_mnk[j][0]; + int TILE_N = tile_mnk[j][1]; + int TILE_K = tile_mnk[j][2]; + + if (TILE_M >= M && TILE_N >= N && TILE_K >= K) + continue; + + int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K); + if (ret != 0) + return ret; + } + + // test no tiling + int ret = test_gemm_0(M, N, K, 100, 100, 100); + if (ret != 0) + return ret; + } +#else + // test nothing for non-int8 build +#endif + + return 0; +} diff --git a/tools/modelwriter.h b/tools/modelwriter.h index ff86338bca9c..218b211901fb 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1773,6 +1773,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) 
fprintf_param_value(" 12=%d", output_elempack) fprintf_param_value(" 13=%d", output_elemtype) fprintf_param_value(" 14=%d", output_transpose) + fprintf_param_value(" 18=%d", int8_scale_term) fprintf_param_value(" 20=%d", constant_TILE_M) fprintf_param_value(" 21=%d", constant_TILE_N) fprintf_param_value(" 22=%d", constant_TILE_K) @@ -1789,6 +1790,23 @@ int ModelWriter::save(const char* parampath, const char* binpath) { fwrite_weight_tag_data(op->C_data, bp); } + +#if NCNN_INT8 + // write int8_scale data + if (op->int8_scale_term) + { + if (op->constantA == 1) + { + fwrite_weight_data(op->A_data_int8_scales, bp, 90, 100); + } + if (op->constantB == 1) + { + ncnn::Mat B_data_int8_scales(1); + B_data_int8_scales[0] = op->B_data_int8_scale; + fwrite_weight_data(B_data_int8_scales, bp, 90, 100); + } + } +#endif // NCNN_INT8 } else if (layer->type == "GLU") { diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index 5e92b333aa57..686accc6089c 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -134,6 +134,7 @@ class NetQuantize : public ModelWriter int quantize_gru(); int quantize_embed(); + int quantize_gemm(); int fuse_requantize(); }; @@ -613,6 +614,113 @@ int NetQuantize::quantize_embed() return 0; } +int NetQuantize::quantize_gemm() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "Gemm") + continue; + + // Gemm - quantize weight from fp32 to int8 + ncnn::Gemm* gemm = (ncnn::Gemm*)layers[i]; + + fprintf(stderr, "quantize_gemm %s\n", gemm->name.c_str()); + + // TODO move to ncnn2table + + if (gemm->constantA) + { + if (gemm->transA == 1) + { + // transpose for easier quantization + ncnn::Mat A_data_transposed(gemm->constantK * gemm->constantM); + for (int i = 0; i < gemm->constantM; i++) + { + float* ptr = (float*)A_data_transposed + i * gemm->constantK; + for (int j = 0; j < gemm->constantK; j++) + { + ptr[j] = gemm->A_data[j * gemm->constantM + i]; + } + } + gemm->A_data = A_data_transposed; + gemm->transA = 0; + } + + gemm->A_data_int8_scales.create(gemm->constantM); + for (int i = 0; i < gemm->constantM; i++) + { + float absmax = 0.f; + + const float* ptr = (const float*)gemm->A_data + i * gemm->constantK; + for (int j = 0; j < gemm->constantK; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + gemm->A_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax; + } + + ncnn::Mat A_data = gemm->A_data.reshape(gemm->constantK, gemm->constantM); + ncnn::Mat A_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = A_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(A_data, A_data_int8, gemm->A_data_int8_scales, opt_q); + if (A_data_int8.empty()) + return -100; + + gemm->A_data = A_data_int8.reshape(gemm->constantK * gemm->constantM); + } + + if (gemm->constantB) + { + if (gemm->transB == 0) + { + // transpose for easier quantization + ncnn::Mat B_data_transposed(gemm->constantK * gemm->constantN); + for (int i = 0; i < gemm->constantN; i++) + { + float* ptr = (float*)B_data_transposed + i * gemm->constantK; + for (int j = 0; j < gemm->constantK; j++) + { + ptr[j] = gemm->B_data[j * gemm->constantN + i]; + } + } + gemm->B_data = B_data_transposed; + gemm->transB = 1; + } + + const float* ptr = gemm->B_data; + float absmax = 0.f; + for (int j = 0; j < gemm->B_data.w; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + gemm->B_data_int8_scale = absmax == 0.f ? 
1.f : 127 / absmax; + + ncnn::Mat B_data_int8_scales(1); + B_data_int8_scales[0] = gemm->B_data_int8_scale; + + ncnn::Mat B_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = gemm->B_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(gemm->B_data, B_data_int8, B_data_int8_scales, opt_q); + if (B_data_int8.empty()) + return -100; + + gemm->B_data = B_data_int8; + } + + gemm->int8_scale_term = 2; + } + + return 0; +} + int NetQuantize::fuse_requantize() { const size_t layer_count = layers.size(); @@ -861,6 +969,7 @@ int main(int argc, char** argv) quantizer.quantize_lstm(); quantizer.quantize_gru(); quantizer.quantize_embed(); + quantizer.quantize_gemm(); quantizer.fuse_requantize(); From 66b54cbea2ed7c6a17e04c2b5ad6b749bc8b76ec Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 15 Oct 2024 16:38:00 +0800 Subject: [PATCH 02/15] multiheadattention int8 quantization (#5733) * x86 vulkan fallback * comment about bf16s --- docs/developer-guide/operators.md | 5 + src/layer/arm/multiheadattention_arm.cpp | 42 +- src/layer/multiheadattention.cpp | 429 ++++++++++++++++++ src/layer/multiheadattention.h | 14 + .../vulkan/multiheadattention_vulkan.cpp | 13 + src/layer/vulkan/multiheadattention_vulkan.h | 2 + src/layer/x86/multiheadattention_x86.cpp | 163 ++++--- tests/test_multiheadattention_1.cpp | 198 ++++++++ tools/modelwriter.h | 14 + tools/quantize/ncnn2int8.cpp | 133 ++++++ 10 files changed, 953 insertions(+), 60 deletions(-) create mode 100644 tests/test_multiheadattention_1.cpp diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 28f1ce626466..4c82fd472c10 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1277,6 +1277,7 @@ y = affine(out) | 4 | vdim | int | embed_dim | | | 5 | attn_mask | int | 0 | | | 6 | scale | float | 1.f / sqrt(embed_dim / num_heads) | | +| 18 | int8_scale_term | int | 0 | | | weight | type | shape | | ------------- | ----- | --------------------- | @@ -1288,6 +1289,10 @@ y = affine(out) | v_bias_data | float | [embed_dim] | | out_weight_data| float/fp16/int8 | [qdim * embed_dim] | | out_bias_data | float | [qdim] | +| q_weight_data_int8_scales| float | [embed_dim] | +| k_weight_data_int8_scales| float | [embed_dim] | +| v_weight_data_int8_scales| float | [embed_dim] | +| out_weight_data_int8_scales| float | [1] | # MVN ``` diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp index 9fedf8b16d74..46a0ec995e43 100644 --- a/src/layer/arm/multiheadattention_arm.cpp +++ b/src/layer/arm/multiheadattention_arm.cpp @@ -28,7 +28,7 @@ MultiHeadAttention_arm::MultiHeadAttention_arm() #endif #endif // __ARM_NEON - support_bf16_storage = false; + support_bf16_storage = false;// TODO enable bf16 when gemm has proper out_elemtype support q_gemm = 0; k_gemm = 0; @@ -76,10 +76,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif q_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = q_weight_data; weights[1] = q_bias_data; +#if NCNN_INT8 + weights[2] = q_weight_data_int8_scales; +#endif q_gemm->load_model(ModelBinFromMatArray(weights)); q_gemm->create_pipeline(opt); @@ -105,10 +111,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose 
+#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif k_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = k_weight_data; weights[1] = k_bias_data; +#if NCNN_INT8 + weights[2] = k_weight_data_int8_scales; +#endif k_gemm->load_model(ModelBinFromMatArray(weights)); k_gemm->create_pipeline(opt); @@ -134,10 +146,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif v_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = v_weight_data; weights[1] = v_bias_data; +#if NCNN_INT8 + weights[2] = v_weight_data_int8_scales; +#endif v_gemm->load_model(ModelBinFromMatArray(weights)); v_gemm->create_pipeline(opt); @@ -161,10 +179,18 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(9, embed_dim); // K = maxk*inch pd.set(10, 4); // constant_broadcast_type_C = null pd.set(11, 0); // output_N1M +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif o_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = out_weight_data; weights[1] = out_bias_data; +#if NCNN_INT8 + Mat out_weight_data_int8_scales(1); + out_weight_data_int8_scales[0] = out_weight_data_int8_scale; + weights[2] = out_weight_data_int8_scales; +#endif o_gemm->load_model(ModelBinFromMatArray(weights)); o_gemm->create_pipeline(opt); @@ -189,6 +215,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif qk_gemm->load_param(pd); qk_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; @@ -211,6 +240,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 1); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif qkv_gemm->load_param(pd); qkv_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; diff --git a/src/layer/multiheadattention.cpp b/src/layer/multiheadattention.cpp index e25eec88a048..253ea47864bf 100644 --- a/src/layer/multiheadattention.cpp +++ b/src/layer/multiheadattention.cpp @@ -31,6 +31,7 @@ int MultiHeadAttention::load_param(const ParamDict& pd) vdim = pd.get(4, embed_dim); attn_mask = pd.get(5, 0); scale = pd.get(6, 1.f / sqrtf(embed_dim / num_heads)); + int8_scale_term = pd.get(18, 0); return 0; } @@ -71,12 +72,29 @@ int MultiHeadAttention::load_model(const ModelBin& mb) if (out_bias_data.empty()) return -100; +#if NCNN_INT8 + if (int8_scale_term) + { + q_weight_data_int8_scales = mb.load(embed_dim, 1); + k_weight_data_int8_scales = mb.load(embed_dim, 1); + v_weight_data_int8_scales = mb.load(embed_dim, 1); + out_weight_data_int8_scale = mb.load(1, 1)[0]; + } +#endif // NCNN_INT8 + return 0; } // refers to https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html int MultiHeadAttention::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& q_blob = bottom_blobs[0]; const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1]; const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? 
q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2]; @@ -316,4 +334,415 @@ int MultiHeadAttention::forward(const std::vector& bottom_blobs, std::vecto return 0; } +#if NCNN_INT8 +static inline signed char float2int8(float v) +{ + int int32 = static_cast(round(v)); + if (int32 > 127) return 127; + if (int32 < -127) return -127; + return (signed char)int32; +} + +static void dynamic_quantize_2d(const Mat& blob, Mat& blob_int8, float& scale, const Option& opt) +{ + blob_int8.create(blob.w, blob.h, (size_t)1u, 1, opt.workspace_allocator); + + float absmax = 0.f; + for (int i = 0; i < blob_int8.h; i++) + { + const float* ptr = blob.row(i); + + for (int j = 0; j < blob_int8.w; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + } + + scale = absmax == 0.f ? 1.f : 127.f / absmax; + + for (int i = 0; i < blob_int8.h; i++) + { + const float* ptr = blob.row(i); + signed char* outptr = blob_int8.row(i); + + for (int j = 0; j < blob_int8.w; j++) + { + outptr[j] = float2int8(ptr[j] * scale); + } + } +} + +static void dynamic_quantize_2d_per_h(const Mat& blob, Mat& blob_int8, Mat& scales, const Option& opt) +{ + blob_int8.create(blob.w, blob.h, (size_t)1u, 1, opt.workspace_allocator); + scales.create(blob.h, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < blob_int8.h; i++) + { + const float* ptr = blob.row(i); + + float absmax = 0.f; + for (int j = 0; j < blob_int8.w; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + scales[i] = absmax == 0.f ? 1.f : 127.f / absmax; + } + + for (int i = 0; i < blob_int8.h; i++) + { + const float* ptr = blob.row(i); + signed char* outptr = blob_int8.row(i); + const float scale = scales[i]; + + for (int j = 0; j < blob_int8.w; j++) + { + outptr[j] = float2int8(ptr[j] * scale); + } + } +} + +int MultiHeadAttention::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& q_blob = bottom_blobs[0]; + const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1]; + const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2]; + const Mat& attn_mask_blob = attn_mask ? 
bottom_blobs[bottom_blobs.size() - 1] : Mat(); + + const int src_seqlen = q_blob.h; + const int dst_seqlen = k_blob.h; + const int embed_dim_per_head = embed_dim / num_heads; + const int qdim = weight_data_size / embed_dim; + + // assert k_blob.h == v_blob.h + + Mat& top_blob = top_blobs[0]; + top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator); + if (xq.empty()) + return -100; + Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator); + if (xk.empty()) + return -100; + Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator); + if (xv.empty()) + return -100; + + Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator); + if (xqk.empty()) + return -100; + + Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator); + if (xqkv.empty()) + return -100; + + // dynamic quantize q_blob + Mat q_blob_int8; + float q_blob_int8_scale; + dynamic_quantize_2d(q_blob, q_blob_int8, q_blob_int8_scale, opt); + + // dynamic quantize k_blob + Mat k_blob_int8; + float k_blob_int8_scale; + if (bottom_blobs.size() == 1) + { + k_blob_int8 = q_blob_int8; + k_blob_int8_scale = q_blob_int8_scale; + } + else + { + dynamic_quantize_2d(k_blob, k_blob_int8, k_blob_int8_scale, opt); + } + + // dynamic quantize v_blob + Mat v_blob_int8; + float v_blob_int8_scale; + if (bottom_blobs.size() == 1) + { + v_blob_int8 = q_blob_int8; + v_blob_int8_scale = q_blob_int8_scale; + } + else if (bottom_blobs.size() == 2) + { + v_blob_int8 = k_blob_int8; + v_blob_int8_scale = k_blob_int8_scale; + } + else + { + dynamic_quantize_2d(v_blob, v_blob_int8, v_blob_int8_scale, opt); + } + + // NCNN_LOGE("%.4f %.4f", q_weight_data_int8_scale, q_blob_int8_scale); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + // xq = affine(q) * scale + { + Mat outm = xq.channel(q); + + for (int i = 0; i < src_seqlen; i++) + { + float* outptr = outm.row(i); + + for (int j = 0; j < embed_dim_per_head; j++) + { + const signed char* ptr = q_blob_int8.row(i); + const signed char* kptr = (const signed char*)q_weight_data + qdim * (q * embed_dim_per_head + j); + + int sum = 0; + for (int k = 0; k < qdim; k++) + { + sum += *ptr++ * *kptr++; + } + const float q_descale = 1.f / (q_weight_data_int8_scales[q * embed_dim_per_head + j] * q_blob_int8_scale); + float sum_fp32 = sum * q_descale + q_bias_data[q * embed_dim_per_head + j]; + + outptr[j] = sum_fp32 * scale; + } + } + } + + // xk = affine(k) + { + float* outptr = xk.channel(q); + + for (int i = 0; i < k_blob_int8.h; i++) + { + for (int j = 0; j < embed_dim_per_head; j++) + { + const signed char* ptr = k_blob_int8.row(i); + const signed char* kptr = (const signed char*)k_weight_data + kdim * (q * embed_dim_per_head + j); + + int sum = 0; + for (int k = 0; k < kdim; k++) + { + sum += *ptr++ * *kptr++; + } + const float k_descale = 1.f / (k_weight_data_int8_scales[q * embed_dim_per_head + j] * k_blob_int8_scale); + float sum_fp32 = sum * k_descale + k_bias_data[q * embed_dim_per_head + j]; + + *outptr++ = sum_fp32; + } + } + } + + // xv = affine(v) + { + Mat outm = xv.channel(q); + + for (int i = 0; i < embed_dim_per_head; i++) + { + float* outptr = outm.row(i); + + for (int j = 0; j < v_blob_int8.h; j++) + { + const signed char* ptr = v_blob_int8.row(j); + const signed char* kptr = (const signed char*)v_weight_data + vdim * (q * embed_dim_per_head + i); + 
+ int sum = 0; + for (int k = 0; k < vdim; k++) + { + sum += *ptr++ * *kptr++; + } + const float v_descale = 1.f / (v_weight_data_int8_scales[q * embed_dim_per_head + i] * v_blob_int8_scale); + float sum_fp32 = sum * v_descale + v_bias_data[q * embed_dim_per_head + i]; + + *outptr++ = sum_fp32; + } + } + } + + // xqk = xq * xk + // xq (embed_dim_per_head, src_seqlen) + // xk (embed_dim_per_head, dst_seqlen) + { + const Mat xqm = xq.channel(q); + const Mat xkm = xk.channel(q); + + Mat outm = xqk.channel(q); + + // dynamic quantize xqm per h + Mat xqm_int8; + Mat xqm_int8_scales; + dynamic_quantize_2d_per_h(xqm, xqm_int8, xqm_int8_scales, opt); + + // dynamic quantize xkm + Mat xkm_int8; + float xkm_int8_scale; + dynamic_quantize_2d(xkm, xkm_int8, xkm_int8_scale, opt); + + for (int i = 0; i < src_seqlen; i++) + { + float* outptr = outm.row(i); + const float xqk_descale = 1.f / (xqm_int8_scales[i] * xkm_int8_scale); + + for (int j = 0; j < dst_seqlen; j++) + { + const signed char* qptr = xqm_int8.row(i); + const signed char* kptr = xkm_int8.row(j); + + int sum = 0; + for (int k = 0; k < embed_dim_per_head; k++) + { + sum += *qptr++ * *kptr++; + } + float sum_fp32 = sum * xqk_descale; + + outptr[j] = sum_fp32; + } + } + } + + // xqk = xqk + mask + if (attn_mask) + { + const Mat& maskm = attn_mask_blob.dims == 3 ? attn_mask_blob.channel(q) : attn_mask_blob; + Mat outm = xqk.channel(q); + + for (int i = 0; i < src_seqlen; i++) + { + const float* mptr = maskm.row(i); + float* outptr = outm.row(i); + + for (int j = 0; j < dst_seqlen; j++) + { + outptr[j] += mptr[j]; + } + } + } + + // softmax(xqk) + { + Mat outm = xqk.channel(q); + + for (int i = 0; i < src_seqlen; i++) + { + float* ptr = outm.row(i); + + float max = -FLT_MAX; + for (int j = 0; j < dst_seqlen; j++) + { + max = std::max(max, ptr[j]); + } + + float sum = 0.f; + for (int j = 0; j < dst_seqlen; j++) + { + ptr[j] = (float)(expf(ptr[j] - max)); + sum += ptr[j]; + } + + for (int j = 0; j < dst_seqlen; j++) + { + ptr[j] /= sum; + } + } + } + + // xqkv = xqk * xv + // xqk (dst_seqlen, src_seqlen) + // xv (dst_seqlen, embed_dim_per_head) + // out (embed_dim_per_head, num_heads, src_seqlen) + { + const Mat xqkm = xqk.channel(q); + const Mat xvm = xv.channel(q); + + // dynamic quantize xqkm + Mat xqkm_int8; + Mat xqkm_int8_scales; + dynamic_quantize_2d_per_h(xqkm, xqkm_int8, xqkm_int8_scales, opt); + + // dynamic quantize xvm per h + Mat xvm_int8; + float xvm_int8_scale; + dynamic_quantize_2d(xvm, xvm_int8, xvm_int8_scale, opt); + + for (int i = 0; i < src_seqlen; i++) + { + float* outptr = xqkv.channel(i).row(q); + const float xqkv_descale = 1.f / (xqkm_int8_scales[i] * xvm_int8_scale); + + for (int j = 0; j < embed_dim_per_head; j++) + { + const signed char* qkptr = xqkm_int8.row(i); + const signed char* vptr = xvm_int8.row(j); + + int sum = 0; + for (int k = 0; k < dst_seqlen; k++) + { + sum += *qkptr++ * *vptr++; + } + float sum_fp32 = sum * xqkv_descale; + + outptr[j] = sum_fp32; + } + } + } + } + + // dynamic quantize xqkv + Mat xqkv_int8; + Mat xqkv_int8_scales; + { + xqkv_int8.create(xqkv.w, xqkv.h, xqkv.c, (size_t)1u, 1, opt.workspace_allocator); + xqkv_int8_scales.create(src_seqlen, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < xqkv_int8.c; i++) + { + const float* ptr = xqkv.channel(i); + + float absmax = 0.f; + for (int j = 0; j < xqkv_int8.w * xqkv_int8.h; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + xqkv_int8_scales[i] = absmax == 0.f ? 
1.f : 127.f / absmax; + } + + for (int i = 0; i < xqkv_int8.c; i++) + { + const float* ptr = xqkv.channel(i); + signed char* outptr = xqkv_int8.channel(i); + + for (int j = 0; j < xqkv_int8.w * xqkv_int8.h; j++) + { + outptr[j] = float2int8(ptr[j] * xqkv_int8_scales[i]); + } + } + } + + // out = affine(xqkv) + // xqkv (embed_dim, src_seqlen) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < src_seqlen; i++) + { + float* outptr = top_blob.row(i); + + for (int j = 0; j < qdim; j++) + { + const signed char* ptr = xqkv_int8.channel(i); + const signed char* kptr = (const signed char*)out_weight_data + embed_dim * j; + + int sum = 0; + for (int k = 0; k < embed_dim; k++) + { + sum += *ptr++ * *kptr++; + } + const float out_descale = 1.f / (out_weight_data_int8_scale * xqkv_int8_scales[i]); + float sum_fp32 = sum * out_descale + out_bias_data[j]; + + outptr[j] = sum_fp32; + } + } + + return 0; +} +#endif + } // namespace ncnn diff --git a/src/layer/multiheadattention.h b/src/layer/multiheadattention.h index 55764bd9c64e..6d32cfae2dd5 100644 --- a/src/layer/multiheadattention.h +++ b/src/layer/multiheadattention.h @@ -30,6 +30,11 @@ class MultiHeadAttention : public Layer virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif + public: int embed_dim; int num_heads; @@ -39,6 +44,8 @@ class MultiHeadAttention : public Layer int attn_mask; float scale; + int int8_scale_term; + Mat q_weight_data; Mat q_bias_data; Mat k_weight_data; @@ -47,6 +54,13 @@ class MultiHeadAttention : public Layer Mat v_bias_data; Mat out_weight_data; Mat out_bias_data; + +#if NCNN_INT8 + Mat q_weight_data_int8_scales; + Mat k_weight_data_int8_scales; + Mat v_weight_data_int8_scales; + float out_weight_data_int8_scale; +#endif }; } // namespace ncnn diff --git a/src/layer/vulkan/multiheadattention_vulkan.cpp b/src/layer/vulkan/multiheadattention_vulkan.cpp index 1abc09c30e6b..b8cfd399cf20 100644 --- a/src/layer/vulkan/multiheadattention_vulkan.cpp +++ b/src/layer/vulkan/multiheadattention_vulkan.cpp @@ -43,6 +43,19 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan() pipeline_multiheadattention_qkv_cross_pack4to1 = 0; } +int MultiHeadAttention_vulkan::load_param(const ParamDict& pd) +{ + int ret = MultiHeadAttention::load_param(pd); + + if (int8_scale_term) + { + support_vulkan = false; + support_image_storage = false; + } + + return ret; +} + int MultiHeadAttention_vulkan::create_pipeline(const Option& opt) { const int embed_dim_per_head = embed_dim / num_heads; diff --git a/src/layer/vulkan/multiheadattention_vulkan.h b/src/layer/vulkan/multiheadattention_vulkan.h index 3b77d96db484..58e06bfc1915 100644 --- a/src/layer/vulkan/multiheadattention_vulkan.h +++ b/src/layer/vulkan/multiheadattention_vulkan.h @@ -24,6 +24,8 @@ class MultiHeadAttention_vulkan : public MultiHeadAttention public: MultiHeadAttention_vulkan(); + virtual int load_param(const ParamDict& pd); + virtual int create_pipeline(const Option& opt); virtual int destroy_pipeline(const Option& opt); diff --git a/src/layer/x86/multiheadattention_x86.cpp b/src/layer/x86/multiheadattention_x86.cpp index 9bddb3a78ef7..08a0c50d462c 100644 --- a/src/layer/x86/multiheadattention_x86.cpp +++ b/src/layer/x86/multiheadattention_x86.cpp @@ -36,8 +36,26 @@ MultiHeadAttention_x86::MultiHeadAttention_x86() o_gemm = 0; } -int 
MultiHeadAttention_x86::create_pipeline(const Option& opt) +int MultiHeadAttention_x86::create_pipeline(const Option& _opt) { + Option opt = _opt; + if (int8_scale_term) + { + support_packing = false; + + opt.use_packing_layout = false; // TODO enable packing + } + + { + qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax); + ncnn::ParamDict pd; + pd.set(0, -1); + pd.set(1, 1); + qk_softmax->load_param(pd); + qk_softmax->load_model(ModelBinFromMatArray(0)); + qk_softmax->create_pipeline(opt); + } + const int qdim = weight_data_size / embed_dim; { @@ -57,10 +75,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif q_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = q_weight_data; weights[1] = q_bias_data; +#if NCNN_INT8 + weights[2] = q_weight_data_int8_scales; +#endif q_gemm->load_model(ModelBinFromMatArray(weights)); q_gemm->create_pipeline(opt); @@ -86,10 +110,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif k_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = k_weight_data; weights[1] = k_bias_data; +#if NCNN_INT8 + weights[2] = k_weight_data_int8_scales; +#endif k_gemm->load_model(ModelBinFromMatArray(weights)); k_gemm->create_pipeline(opt); @@ -115,10 +145,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 0); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif v_gemm->load_param(pd); - Mat weights[2]; + Mat weights[3]; weights[0] = v_weight_data; weights[1] = v_bias_data; +#if NCNN_INT8 + weights[2] = v_weight_data_int8_scales; +#endif v_gemm->load_model(ModelBinFromMatArray(weights)); v_gemm->create_pipeline(opt); @@ -129,6 +165,41 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) } } + { + o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); + ncnn::ParamDict pd; + pd.set(2, 1); // transA + pd.set(3, 1); // transB + pd.set(4, 0); // constantA + pd.set(5, 1); // constantB + pd.set(6, 1); // constantC + pd.set(7, 0); // M = outch + pd.set(8, qdim); // N = size + pd.set(9, embed_dim); // K = maxk*inch + pd.set(10, 4); // constant_broadcast_type_C + pd.set(11, 0); // output_N1M +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif + o_gemm->load_param(pd); + Mat weights[3]; + weights[0] = out_weight_data; + weights[1] = out_bias_data; +#if NCNN_INT8 + Mat out_weight_data_int8_scales(1); + out_weight_data_int8_scales[0] = out_weight_data_int8_scale; + weights[2] = out_weight_data_int8_scales; +#endif + o_gemm->load_model(ModelBinFromMatArray(weights)); + o_gemm->create_pipeline(opt); + + if (opt.lightmode) + { + out_weight_data.release(); + out_bias_data.release(); + } + } + { qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; @@ -143,12 +214,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif qk_gemm->load_param(pd); qk_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; opt1.num_threads = 1; qk_gemm->create_pipeline(opt1); } + { qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); ncnn::ParamDict pd; @@ -164,6 +239,9 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) pd.set(11, 0); // output_N1M pd.set(12, 1); // output_elempack pd.set(14, 1); // output_transpose +#if NCNN_INT8 + pd.set(18, int8_scale_term); +#endif qkv_gemm->load_param(pd); qkv_gemm->load_model(ModelBinFromMatArray(0)); Option opt1 = opt; @@ -171,48 +249,24 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt) qkv_gemm->create_pipeline(opt1); } + return 0; +} + +int MultiHeadAttention_x86::destroy_pipeline(const Option& _opt) +{ + Option opt = _opt; + if (int8_scale_term) { - qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax); - ncnn::ParamDict pd; - pd.set(0, -1); - pd.set(1, 1); - qk_softmax->load_param(pd); - qk_softmax->load_model(ModelBinFromMatArray(0)); - qk_softmax->create_pipeline(opt); + opt.use_packing_layout = false; // TODO enable packing } + if (qk_softmax) { - o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 1); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 1); // constantB - pd.set(6, 1); // constantC - pd.set(7, 0); // M = outch - pd.set(8, qdim); // N = size - pd.set(9, embed_dim); // K = maxk*inch - pd.set(10, 4); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - o_gemm->load_param(pd); - Mat weights[2]; - weights[0] = out_weight_data; - weights[1] = out_bias_data; - o_gemm->load_model(ModelBinFromMatArray(weights)); - o_gemm->create_pipeline(opt); - - if (opt.lightmode) - { - out_weight_data.release(); - out_bias_data.release(); - } + qk_softmax->destroy_pipeline(opt); + delete qk_softmax; + qk_softmax = 0; } - return 0; -} - -int MultiHeadAttention_x86::destroy_pipeline(const Option& opt) -{ if (q_gemm) { q_gemm->destroy_pipeline(opt); @@ -234,6 +288,13 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt) v_gemm = 0; } + if (o_gemm) + { + o_gemm->destroy_pipeline(opt); + delete o_gemm; + o_gemm = 0; + } + if (qk_gemm) { qk_gemm->destroy_pipeline(opt); @@ -247,30 +308,22 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt) qkv_gemm = 0; } - if (qk_softmax) - { - qk_softmax->destroy_pipeline(opt); - delete qk_softmax; - qk_softmax = 0; - } - - if (o_gemm) - { - o_gemm->destroy_pipeline(opt); - delete o_gemm; - o_gemm = 0; - } - return 0; } -int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int MultiHeadAttention_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { const Mat& q_blob = bottom_blobs[0]; const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1]; const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? 
bottom_blobs[bottom_blobs.size() - 1] : Mat(); + Option opt = _opt; + if (int8_scale_term) + { + opt.use_packing_layout = false; // TODO enable packing + } + Mat attn_mask_blob_unpacked; if (attn_mask && attn_mask_blob.elempack != 1) { diff --git a/tests/test_multiheadattention_1.cpp b/tests/test_multiheadattention_1.cpp new file mode 100644 index 000000000000..c29930a0be8c --- /dev/null +++ b/tests/test_multiheadattention_1.cpp @@ -0,0 +1,198 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +#if NCNN_INT8 +static int test_multiheadattention_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask) +{ + const int qdim = q.w; + const int kdim = k.w; + const int vdim = v.w; + + ncnn::ParamDict pd; + pd.set(0, embed_dim); + pd.set(1, num_heads); + pd.set(2, embed_dim * qdim); + pd.set(3, kdim); + pd.set(4, vdim); + pd.set(5, attn_mask); + pd.set(6, 1.f / sqrtf(embed_dim / num_heads)); + pd.set(18, 2); // int8_scale_term + + std::vector weights(12); + weights[0] = RandomS8Mat(embed_dim * qdim); + weights[1] = RandomMat(embed_dim); + weights[2] = RandomS8Mat(embed_dim * kdim); + weights[3] = RandomMat(embed_dim); + weights[4] = RandomS8Mat(embed_dim * vdim); + weights[5] = RandomMat(embed_dim); + weights[6] = RandomS8Mat(qdim * embed_dim); + weights[7] = RandomMat(qdim); + weights[8] = RandomMat(embed_dim, 160.f, 200.f); + weights[9] = RandomMat(embed_dim, 160.f, 200.f); + weights[10] = RandomMat(embed_dim, 160.f, 200.f); + weights[11] = RandomMat(1, 160.f, 200.f); + + std::vector as(3); + as[0] = q; + as[1] = k; + as[2] = v; + + if (attn_mask) + { + as.push_back(RandomMat(k.h, q.h)); + } + + float epsilon = 0.15; + + int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); + if (ret != 0) + { + fprintf(stderr, "test_multiheadattention_int8 failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask); + } + + return ret; +} + +static int test_multiheadattention_int8_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads) +{ + const int qdim = q.w; + const int kvdim = kv.w; + + ncnn::ParamDict pd; + pd.set(0, embed_dim); + pd.set(1, num_heads); + pd.set(2, embed_dim * qdim); + pd.set(3, kvdim); + pd.set(4, kvdim); + pd.set(6, 1.f / sqrtf(embed_dim / num_heads)); + pd.set(18, 2); // int8_scale_term + + std::vector weights(12); + weights[0] = RandomS8Mat(embed_dim * qdim); + weights[1] = RandomMat(embed_dim); + weights[2] = RandomS8Mat(embed_dim * kvdim); + weights[3] = RandomMat(embed_dim); + weights[4] = RandomS8Mat(embed_dim * kvdim); + weights[5] = RandomMat(embed_dim); + weights[6] = RandomS8Mat(qdim * embed_dim); + weights[7] = RandomMat(qdim); + weights[8] = RandomMat(embed_dim, 160.f, 200.f); + 
weights[9] = RandomMat(embed_dim, 160.f, 200.f); + weights[10] = RandomMat(embed_dim, 160.f, 200.f); + weights[11] = RandomMat(1, 160.f, 200.f); + + std::vector as(2); + as[0] = q; + as[1] = kv; + + float epsilon = 0.15; + + int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); + if (ret != 0) + { + fprintf(stderr, "test_multiheadattention_int8_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim); + } + + return ret; +} + +static int test_multiheadattention_int8_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads) +{ + const int qdim = a.w; + + ncnn::ParamDict pd; + pd.set(0, embed_dim); + pd.set(1, num_heads); + pd.set(2, embed_dim * qdim); + pd.set(3, qdim); + pd.set(4, qdim); + pd.set(6, 1.f / sqrtf(embed_dim / num_heads)); + pd.set(18, 2); // int8_scale_term + + std::vector weights(12); + weights[0] = RandomS8Mat(embed_dim * qdim); + weights[1] = RandomMat(embed_dim); + weights[2] = RandomS8Mat(embed_dim * qdim); + weights[3] = RandomMat(embed_dim); + weights[4] = RandomS8Mat(embed_dim * qdim); + weights[5] = RandomMat(embed_dim); + weights[6] = RandomS8Mat(qdim * embed_dim); + weights[7] = RandomMat(qdim); + weights[8] = RandomMat(embed_dim, 160.f, 200.f); + weights[9] = RandomMat(embed_dim, 160.f, 200.f); + weights[10] = RandomMat(embed_dim, 160.f, 200.f); + weights[11] = RandomMat(1, 160.f, 200.f); + + std::vector as(1); + as[0] = a; + + float epsilon = 0.15; + + int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); + if (ret != 0) + { + fprintf(stderr, "test_multiheadattention_int8_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads); + } + + return ret; +} + +static int test_multiheadattention_0() +{ + return 0 + || test_multiheadattention_int8(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0) + || test_multiheadattention_int8(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1) + || test_multiheadattention_int8(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0) + || test_multiheadattention_int8(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1) + || test_multiheadattention_int8(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0) + || test_multiheadattention_int8(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1) + || test_multiheadattention_int8(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0) + || test_multiheadattention_int8(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1); +} + +static int test_multiheadattention_1() +{ + return 0 + || test_multiheadattention_int8_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4) + || test_multiheadattention_int8_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16) + || test_multiheadattention_int8_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2) + || test_multiheadattention_int8_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4) + || test_multiheadattention_int8_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3) + || test_multiheadattention_int8_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3); +} + +static int test_multiheadattention_2() +{ + return 0 + || test_multiheadattention_int8_sameqkv(RandomMat(64, 128), 64, 4) + || test_multiheadattention_int8_sameqkv(RandomMat(48, 127), 64, 8); +} +#endif + +int main() +{ + SRAND(7767517); + +#if NCNN_INT8 + return 0 + || test_multiheadattention_0() + || test_multiheadattention_1() + || 
test_multiheadattention_2(); +#else + // test nothing + return 0; +#endif +} diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 218b211901fb..6d73952e4cf1 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -2038,6 +2038,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 4=%d", vdim) fprintf_param_value(" 5=%d", attn_mask) fprintf_param_value(" 6=%e", scale) + fprintf_param_value(" 18=%d", int8_scale_term) fwrite_weight_tag_data(op->q_weight_data, bp); fwrite_weight_data(op->q_bias_data, bp); @@ -2047,6 +2048,19 @@ int ModelWriter::save(const char* parampath, const char* binpath) fwrite_weight_data(op->v_bias_data, bp); fwrite_weight_tag_data(op->out_weight_data, bp); fwrite_weight_data(op->out_bias_data, bp); + +#if NCNN_INT8 + // write int8_scale data + if (op->int8_scale_term) + { + fwrite_weight_data(op->q_weight_data_int8_scales, bp, 90, 100); + fwrite_weight_data(op->k_weight_data_int8_scales, bp, 90, 100); + fwrite_weight_data(op->v_weight_data_int8_scales, bp, 90, 100); + ncnn::Mat out_weight_data_int8_scales(1); + out_weight_data_int8_scales[0] = op->out_weight_data_int8_scale; + fwrite_weight_data(out_weight_data_int8_scales, bp, 90, 100); + } +#endif // NCNN_INT8 } else if (layer->type == "MVN") { diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index 686accc6089c..cbc67d1d9e3f 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -135,6 +135,7 @@ class NetQuantize : public ModelWriter int quantize_embed(); int quantize_gemm(); + int quantize_multiheadattention(); int fuse_requantize(); }; @@ -721,6 +722,137 @@ int NetQuantize::quantize_gemm() return 0; } +int NetQuantize::quantize_multiheadattention() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "MultiHeadAttention") + continue; + + // MultiHeadAttention - quantize weight from fp32 to int8 + ncnn::MultiHeadAttention* mha = (ncnn::MultiHeadAttention*)layers[i]; + + fprintf(stderr, "quantize_multiheadattention %s\n", mha->name.c_str()); + + // TODO move to ncnn2table + + const int qdim = mha->weight_data_size / mha->embed_dim; + + { + mha->q_weight_data_int8_scales.create(mha->embed_dim); + for (int i = 0; i < mha->embed_dim; i++) + { + float absmax = 0.f; + + const float* ptr = (const float*)mha->q_weight_data + i * qdim; + for (int j = 0; j < qdim; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + mha->q_weight_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax; + } + + ncnn::Mat q_weight_data = mha->q_weight_data.reshape(qdim, mha->embed_dim); + ncnn::Mat q_weight_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = q_weight_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(q_weight_data, q_weight_data_int8, mha->q_weight_data_int8_scales, opt_q); + if (q_weight_data_int8.empty()) + return -100; + + mha->q_weight_data = q_weight_data_int8.reshape(qdim * mha->embed_dim); + } + + { + mha->k_weight_data_int8_scales.create(mha->embed_dim); + for (int i = 0; i < mha->embed_dim; i++) + { + float absmax = 0.f; + + const float* ptr = (const float*)mha->k_weight_data + i * mha->kdim; + for (int j = 0; j < mha->kdim; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + mha->k_weight_data_int8_scales[i] = absmax == 0.f ? 
1.f : 127 / absmax; + } + + ncnn::Mat k_weight_data = mha->k_weight_data.reshape(mha->kdim, mha->embed_dim); + ncnn::Mat k_weight_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = k_weight_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(k_weight_data, k_weight_data_int8, mha->k_weight_data_int8_scales, opt_q); + if (k_weight_data_int8.empty()) + return -100; + + mha->k_weight_data = k_weight_data_int8.reshape(mha->kdim * mha->embed_dim); + } + + { + mha->v_weight_data_int8_scales.create(mha->embed_dim); + for (int i = 0; i < mha->embed_dim; i++) + { + float absmax = 0.f; + + const float* ptr = (const float*)mha->v_weight_data + i * mha->vdim; + for (int j = 0; j < mha->vdim; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + mha->v_weight_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax; + } + + ncnn::Mat v_weight_data = mha->v_weight_data.reshape(mha->vdim, mha->embed_dim); + ncnn::Mat v_weight_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = v_weight_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(v_weight_data, v_weight_data_int8, mha->v_weight_data_int8_scales, opt_q); + if (v_weight_data_int8.empty()) + return -100; + + mha->v_weight_data = v_weight_data_int8.reshape(mha->vdim * mha->embed_dim); + } + + { + const float* ptr = mha->out_weight_data; + float absmax = 0.f; + for (int j = 0; j < mha->out_weight_data.w; j++) + { + absmax = std::max(absmax, (float)fabs(ptr[j])); + } + + mha->out_weight_data_int8_scale = absmax == 0.f ? 1.f : 127 / absmax; + + ncnn::Mat out_weight_data_int8_scales(1); + out_weight_data_int8_scales[0] = mha->out_weight_data_int8_scale; + + ncnn::Mat out_weight_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = mha->out_weight_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(mha->out_weight_data, out_weight_data_int8, out_weight_data_int8_scales, opt_q); + if (out_weight_data_int8.empty()) + return -100; + + mha->out_weight_data = out_weight_data_int8; + } + + mha->int8_scale_term = 2; + } + + return 0; +} + int NetQuantize::fuse_requantize() { const size_t layer_count = layers.size(); @@ -970,6 +1102,7 @@ int main(int argc, char** argv) quantizer.quantize_gru(); quantizer.quantize_embed(); quantizer.quantize_gemm(); + quantizer.quantize_multiheadattention(); quantizer.fuse_requantize(); From 121b1fecd5cf1b12df233c2fbc97baec873983f9 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 15 Oct 2024 08:40:09 +0000 Subject: [PATCH 03/15] apply code-format changes --- src/layer/arm/multiheadattention_arm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp index 46a0ec995e43..1e9148730d7c 100644 --- a/src/layer/arm/multiheadattention_arm.cpp +++ b/src/layer/arm/multiheadattention_arm.cpp @@ -28,7 +28,7 @@ MultiHeadAttention_arm::MultiHeadAttention_arm() #endif #endif // __ARM_NEON - support_bf16_storage = false;// TODO enable bf16 when gemm has proper out_elemtype support + support_bf16_storage = false; // TODO enable bf16 when gemm has proper out_elemtype support q_gemm = 0; k_gemm = 0; From f8560112975a80dcdac84a7d0bd6284467715a49 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 15 Oct 2024 19:28:53 +0800 Subject: [PATCH 04/15] pnnx drop onnx weight-like graph input (#5736) --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/load_onnx.cpp | 3 + .../pass_onnx/eliminate_initializer_input.cpp | 74 
+++++++++++++++++++ .../pass_onnx/eliminate_initializer_input.h | 25 +++++++ 4 files changed, 103 insertions(+) create mode 100644 tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp create mode 100644 tools/pnnx/src/pass_onnx/eliminate_initializer_input.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 7743a8ae453e..2281875dbd43 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -662,6 +662,7 @@ if(onnxruntime_FOUND) set(pnnx_pass_onnx_SRCS pass_onnx/canonicalize.cpp pass_onnx/dead_code_elimination.cpp + pass_onnx/eliminate_initializer_input.cpp pass_onnx/eliminate_noop.cpp pass_onnx/fold_constants.cpp pass_onnx/inline_containers.cpp diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp index 9adf2b470888..e39c20296598 100644 --- a/tools/pnnx/src/load_onnx.cpp +++ b/tools/pnnx/src/load_onnx.cpp @@ -29,6 +29,7 @@ #include "pass_onnx/canonicalize.h" #include "pass_onnx/dead_code_elimination.h" +#include "pass_onnx/eliminate_initializer_input.h" #include "pass_onnx/eliminate_noop.h" #include "pass_onnx/fold_constants.h" #include "pass_onnx/inline_containers.h" @@ -531,6 +532,8 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph, return -1; } + onnx2pnnx::eliminate_initializer_input(model); + // input shape sanity check if (!check_input_shape(model, input_shapes, input_types)) { diff --git a/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp new file mode 100644 index 000000000000..be447bd26da3 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp @@ -0,0 +1,74 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_initializer_input.h" + +#include +#include +#include + +namespace pnnx { + +namespace onnx2pnnx { + +void eliminate_initializer_input(onnx::ModelProto& model) +{ + // collect initializers + std::unordered_set initializers; + { + const onnx::GraphProto& graph = model.graph(); + for (int i = 0; i < graph.initializer_size(); i++) + { + initializers.insert(graph.initializer(i).name()); + } + } + + // collect initializer graph input + std::vector initializer_input_indexes; + { + const onnx::GraphProto& graph = model.graph(); + for (int i = 0; i < graph.input_size(); i++) + { + const std::string& input_name = graph.input(i).name(); + if (initializers.find(input_name) == initializers.end()) + continue; + + initializer_input_indexes.push_back(i); + } + } + + // eliminate initializer graph input + { + onnx::GraphProto* graph = model.mutable_graph(); + + for (size_t i = 0; i < initializer_input_indexes.size(); i++) + { + const int initializer_input_index = initializer_input_indexes[i]; + + // ..... iii ....... 
+ const int graph_input_size = graph->input_size(); + for (int j = initializer_input_index; j < graph_input_size - 1; j++) + { + graph->mutable_input()->SwapElements(j, j + 1); + } + + // ..... ....... iii + graph->mutable_input()->RemoveLast(); + } + } +} + +} // namespace onnx2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h new file mode 100644 index 000000000000..f82b71cd1876 --- /dev/null +++ b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "onnx-ml.pb.h" + +namespace pnnx { + +namespace onnx2pnnx { + +void eliminate_initializer_input(onnx::ModelProto& model); + +} // namespace onnx2pnnx + +} // namespace pnnx From 8105c75120f4f3a996395f21323e0578d65664dc Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 17 Oct 2024 10:09:52 +0800 Subject: [PATCH 05/15] improve compatibility of harmonyos cpu topology abi (#5740) --- src/cpu.cpp | 71 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 10 deletions(-) diff --git a/src/cpu.cpp b/src/cpu.cpp index f9e64a1cc75b..9ab0ebb31e99 100644 --- a/src/cpu.cpp +++ b/src/cpu.cpp @@ -793,20 +793,66 @@ static int get_thread_siblings(int cpuid) char path[256]; sprintf(path, "/sys/devices/system/cpu/cpu%d/topology/thread_siblings", cpuid); - FILE* fp = fopen(path, "rb"); - if (!fp) - return -1; - - int thread_siblings = -1; - int nscan = fscanf(fp, "%x", &thread_siblings); - if (nscan != 1) + FILE* fp = 0; //fopen(path, "rb"); + if (fp) { - // ignore + int thread_siblings = -1; + int nscan = fscanf(fp, "%x", &thread_siblings); + if (nscan != 1) + { + // ignore + } + + fclose(fp); + + return thread_siblings; } - fclose(fp); + // second try, parse from human-readable thread_siblings_list + sprintf(path, "/sys/devices/system/cpu/cpu%d/topology/thread_siblings_list", cpuid); + + fp = fopen(path, "rb"); + if (fp) + { + int thread_siblings = -1; + + int id0; + char sep; + int id1; + + int nscan = fscanf(fp, "%d", &id0); + if (nscan == 1) + { + thread_siblings = (1 << id0); + + while (fscanf(fp, "%c%d", &sep, &id1) == 2) + { + if (sep == ',') + { + thread_siblings |= (1 << id1); + } + if (sep == '-' && id0 < id1) + { + for (int i = id0 + 1; i <= id1; i++) + { + thread_siblings |= (1 << i); + } + } + + id0 = id1; + } + } + else + { + // ignore + } + + fclose(fp); + + return thread_siblings; + } - return thread_siblings; + return -1; } #endif // defined __ANDROID__ || defined __linux__ @@ -869,6 +915,11 @@ static int get_physical_cpucount() count++; } } + if (count == 0) + { + // cannot resolve siblings, fallback to all cpu count + count = g_cpucount; + } #elif __APPLE__ size_t len = sizeof(count); sysctlbyname("hw.physicalcpu_max", &count, &len, NULL, 0); From 
bd1f39ed82b88e3a894b52fdfcbeeb64742a13d4 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 17 Oct 2024 10:13:03 +0800 Subject: [PATCH 06/15] blacklist mesa vulkan cooperative matrix feature (#5739) ref https://gitlab.freedesktop.org/mesa/mesa/-/issues/10847 --- src/gpu.cpp | 67 +++++++++++++++++++++++++++++++++++------ src/gpu.h | 5 +++ src/vulkan_header_fix.h | 59 ++++++++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 9 deletions(-) diff --git a/src/gpu.cpp b/src/gpu.cpp index 5b34c224da08..cdcec0f7f02e 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -269,6 +269,10 @@ class GpuInfoPrivate char device_name[VK_MAX_PHYSICAL_DEVICE_NAME_SIZE]; uint8_t pipeline_cache_uuid[VK_UUID_SIZE]; + // driver properties + uint32_t driver_id; + char driver_name[VK_MAX_DRIVER_NAME_SIZE]; + // 0 = discrete gpu // 1 = integrated gpu // 2 = virtual gpu @@ -349,6 +353,7 @@ class GpuInfoPrivate int support_VK_KHR_cooperative_matrix; int support_VK_KHR_dedicated_allocation; int support_VK_KHR_descriptor_update_template; + int support_VK_KHR_driver_properties; int support_VK_KHR_external_memory; int support_VK_KHR_get_memory_requirements2; int support_VK_KHR_maintenance1; @@ -434,6 +439,16 @@ uint8_t* GpuInfo::pipeline_cache_uuid() const return d->pipeline_cache_uuid; } +uint32_t GpuInfo::driver_id() const +{ + return d->driver_id; +} + +const char* GpuInfo::driver_name() const +{ + return d->driver_name; +} + int GpuInfo::type() const { return d->type; @@ -709,6 +724,11 @@ int GpuInfo::support_VK_KHR_descriptor_update_template() const return d->support_VK_KHR_descriptor_update_template; } +int GpuInfo::support_VK_KHR_driver_properties() const +{ + return d->support_VK_KHR_driver_properties; +} + int GpuInfo::support_VK_KHR_external_memory() const { return d->support_VK_KHR_external_memory; @@ -1438,15 +1458,15 @@ int create_gpu_instance(const char* driver_path) VkPhysicalDeviceProperties physicalDeviceProperties; vkGetPhysicalDeviceProperties(physicalDevice, &physicalDeviceProperties); - // NCNN_LOGE("[%u] apiVersion = %u.%u.%u", i, VK_VERSION_MAJOR(physicalDeviceProperties.apiVersion), - // VK_VERSION_MINOR(physicalDeviceProperties.apiVersion), VK_VERSION_PATCH(physicalDeviceProperties.apiVersion)); - // NCNN_LOGE("[%u] driverVersion = %u.%u.%u", i, VK_VERSION_MAJOR(physicalDeviceProperties.driverVersion), - // VK_VERSION_MINOR(physicalDeviceProperties.driverVersion), VK_VERSION_PATCH(physicalDeviceProperties.driverVersion)); - // NCNN_LOGE("[%u] vendorID = %x", i, physicalDeviceProperties.vendorID); - // NCNN_LOGE("[%u] deviceID = %x", i, physicalDeviceProperties.deviceID); - // NCNN_LOGE("[%u] deviceType = %x", i, physicalDeviceProperties.deviceType); - // NCNN_LOGE("[%u] deviceName = %s", i, physicalDeviceProperties.deviceName); - // NCNN_LOGE("[%u] pipelineCacheUUID = %u", i, physicalDeviceProperties.pipelineCacheUUID); + // NCNN_LOGE("[%u] apiVersion = %u.%u.%u", i, VK_VERSION_MAJOR(physicalDeviceProperties.apiVersion), + // VK_VERSION_MINOR(physicalDeviceProperties.apiVersion), VK_VERSION_PATCH(physicalDeviceProperties.apiVersion)); + // NCNN_LOGE("[%u] driverVersion = %u.%u.%u", i, VK_VERSION_MAJOR(physicalDeviceProperties.driverVersion), + // VK_VERSION_MINOR(physicalDeviceProperties.driverVersion), VK_VERSION_PATCH(physicalDeviceProperties.driverVersion)); + // NCNN_LOGE("[%u] vendorID = %x", i, physicalDeviceProperties.vendorID); + // NCNN_LOGE("[%u] deviceID = %x", i, physicalDeviceProperties.deviceID); + // NCNN_LOGE("[%u] deviceType = %x", i, physicalDeviceProperties.deviceType); + // 
NCNN_LOGE("[%u] deviceName = %s", i, physicalDeviceProperties.deviceName); + // NCNN_LOGE("[%u] pipelineCacheUUID = %u", i, physicalDeviceProperties.pipelineCacheUUID); // mali // t760 = 0x13b5 0x7500001 / 0x7501000 @@ -1676,6 +1696,7 @@ int create_gpu_instance(const char* driver_path) gpu_info.support_VK_KHR_cooperative_matrix = 0; gpu_info.support_VK_KHR_dedicated_allocation = 0; gpu_info.support_VK_KHR_descriptor_update_template = 0; + gpu_info.support_VK_KHR_driver_properties = 0; gpu_info.support_VK_KHR_external_memory = 0; gpu_info.support_VK_KHR_get_memory_requirements2 = 0; gpu_info.support_VK_KHR_maintenance1 = 0; @@ -1720,6 +1741,8 @@ int create_gpu_instance(const char* driver_path) gpu_info.support_VK_KHR_dedicated_allocation = exp.specVersion; else if (strcmp(exp.extensionName, "VK_KHR_descriptor_update_template") == 0) gpu_info.support_VK_KHR_descriptor_update_template = exp.specVersion; + else if (strcmp(exp.extensionName, "VK_KHR_driver_properties") == 0) + gpu_info.support_VK_KHR_driver_properties = exp.specVersion; else if (strcmp(exp.extensionName, "VK_KHR_external_memory") == 0) gpu_info.support_VK_KHR_external_memory = exp.specVersion; else if (strcmp(exp.extensionName, "VK_KHR_get_memory_requirements2") == 0) @@ -1793,6 +1816,8 @@ int create_gpu_instance(const char* driver_path) gpu_info.support_cooperative_matrix_16_8_8 = false; gpu_info.support_cooperative_matrix_16_8_16 = false; gpu_info.support_cooperative_matrix_16_16_16 = false; + gpu_info.driver_id = 0; + gpu_info.driver_name[0] = '\0'; if (support_VK_KHR_get_physical_device_properties2) { void* queryExtensionFeatures = 0; @@ -1855,6 +1880,16 @@ int create_gpu_instance(const char* driver_path) queryExtensionFeatures = &queryCooperativeMatrixFeaturesNV; } + // query driver properties + VkPhysicalDeviceDriverPropertiesKHR queryDriverProperties; + queryDriverProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES_KHR; + queryDriverProperties.pNext = 0; + if (gpu_info.support_VK_KHR_driver_properties) + { + queryDriverProperties.pNext = queryExtensionFeatures; + queryExtensionFeatures = &queryDriverProperties; + } + VkPhysicalDeviceFeatures2KHR queryFeatures; queryFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2_KHR; queryFeatures.pNext = queryExtensionFeatures; @@ -1889,6 +1924,11 @@ int create_gpu_instance(const char* driver_path) { gpu_info.support_cooperative_matrix = queryCooperativeMatrixFeaturesNV.cooperativeMatrix; } + if (gpu_info.support_VK_KHR_driver_properties) + { + gpu_info.driver_id = queryDriverProperties.driverID; + memcpy(gpu_info.driver_name, queryDriverProperties.driverName, VK_MAX_DRIVER_NAME_SIZE); + } } else { @@ -1921,6 +1961,13 @@ int create_gpu_instance(const char* driver_path) gpu_info.support_fp16_arithmetic = false; } + if (gpu_info.driver_id == VK_DRIVER_ID_MESA_RADV || gpu_info.driver_id == VK_DRIVER_ID_INTEL_OPEN_SOURCE_MESA) + { + // cooperative matrix produces wrong result on mesa vulkan drivers :( + // https://gitlab.freedesktop.org/mesa/mesa/-/issues/10847 + gpu_info.support_cooperative_matrix = false; + } + if (gpu_info.support_cooperative_matrix) { // query supported cooperative matrix types and operations @@ -2462,6 +2509,8 @@ VulkanDevice::VulkanDevice(int device_index) enabledExtensions.push_back("VK_KHR_dedicated_allocation"); if (info.support_VK_KHR_descriptor_update_template()) enabledExtensions.push_back("VK_KHR_descriptor_update_template"); + if (info.support_VK_KHR_driver_properties()) + 
enabledExtensions.push_back("VK_KHR_driver_properties"); if (info.support_VK_KHR_external_memory()) enabledExtensions.push_back("VK_KHR_external_memory"); if (info.support_VK_KHR_get_memory_requirements2()) diff --git a/src/gpu.h b/src/gpu.h index 4d131f71c8bd..b98827b69ebd 100644 --- a/src/gpu.h +++ b/src/gpu.h @@ -207,6 +207,10 @@ class NCNN_EXPORT GpuInfo const char* device_name() const; uint8_t* pipeline_cache_uuid() const; + // driver properties + uint32_t driver_id() const; + const char* driver_name() const; + // 0 = discrete gpu // 1 = integrated gpu // 2 = virtual gpu @@ -287,6 +291,7 @@ class NCNN_EXPORT GpuInfo int support_VK_KHR_cooperative_matrix() const; int support_VK_KHR_dedicated_allocation() const; int support_VK_KHR_descriptor_update_template() const; + int support_VK_KHR_driver_properties() const; int support_VK_KHR_external_memory() const; int support_VK_KHR_get_memory_requirements2() const; int support_VK_KHR_maintenance1() const; diff --git a/src/vulkan_header_fix.h b/src/vulkan_header_fix.h index 1402bba2bed1..54c135e3330f 100644 --- a/src/vulkan_header_fix.h +++ b/src/vulkan_header_fix.h @@ -1219,4 +1219,63 @@ typedef struct VkPhysicalDeviceCooperativeMatrixPropertiesKHR typedef VkResult(VKAPI_PTR* PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)(VkPhysicalDevice physicalDevice, uint32_t* pPropertyCount, VkCooperativeMatrixPropertiesKHR* pProperties); #endif // VK_KHR_cooperative_matrix +#ifndef VK_KHR_driver_properties +#define VK_KHR_driver_properties 1 +#define VK_MAX_DRIVER_NAME_SIZE 256U +#define VK_MAX_DRIVER_INFO_SIZE 256U +#define VK_MAX_DRIVER_NAME_SIZE_KHR VK_MAX_DRIVER_NAME_SIZE +#define VK_MAX_DRIVER_INFO_SIZE_KHR VK_MAX_DRIVER_INFO_SIZE +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES (VkStructureType)1000196000 +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES_KHR VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES +typedef enum VkDriverId +{ + VK_DRIVER_ID_AMD_PROPRIETARY = 1, + VK_DRIVER_ID_AMD_OPEN_SOURCE = 2, + VK_DRIVER_ID_MESA_RADV = 3, + VK_DRIVER_ID_NVIDIA_PROPRIETARY = 4, + VK_DRIVER_ID_INTEL_PROPRIETARY_WINDOWS = 5, + VK_DRIVER_ID_INTEL_OPEN_SOURCE_MESA = 6, + VK_DRIVER_ID_IMAGINATION_PROPRIETARY = 7, + VK_DRIVER_ID_QUALCOMM_PROPRIETARY = 8, + VK_DRIVER_ID_ARM_PROPRIETARY = 9, + VK_DRIVER_ID_GOOGLE_SWIFTSHADER = 10, + VK_DRIVER_ID_GGP_PROPRIETARY = 11, + VK_DRIVER_ID_BROADCOM_PROPRIETARY = 12, + VK_DRIVER_ID_MESA_LLVMPIPE = 13, + VK_DRIVER_ID_MOLTENVK = 14, + VK_DRIVER_ID_COREAVI_PROPRIETARY = 15, + VK_DRIVER_ID_JUICE_PROPRIETARY = 16, + VK_DRIVER_ID_VERISILICON_PROPRIETARY = 17, + VK_DRIVER_ID_MESA_TURNIP = 18, + VK_DRIVER_ID_MESA_V3DV = 19, + VK_DRIVER_ID_MESA_PANVK = 20, + VK_DRIVER_ID_SAMSUNG_PROPRIETARY = 21, + VK_DRIVER_ID_MESA_VENUS = 22, + VK_DRIVER_ID_MESA_DOZEN = 23, + VK_DRIVER_ID_MESA_NVK = 24, + VK_DRIVER_ID_IMAGINATION_OPEN_SOURCE_MESA = 25, + VK_DRIVER_ID_MESA_AGXV = 26, + VK_DRIVER_ID_MAX_ENUM = 0x7FFFFFFF +} VkDriverId; +typedef struct VkConformanceVersion +{ + uint8_t major; + uint8_t minor; + uint8_t subminor; + uint8_t patch; +} VkConformanceVersion; +typedef struct VkPhysicalDeviceDriverProperties +{ + VkStructureType sType; + void* pNext; + VkDriverId driverID; + char driverName[VK_MAX_DRIVER_NAME_SIZE]; + char driverInfo[VK_MAX_DRIVER_INFO_SIZE]; + VkConformanceVersion conformanceVersion; +} VkPhysicalDeviceDriverProperties; +typedef VkDriverId VkDriverIdKHR; +typedef VkConformanceVersion VkConformanceVersionKHR; +typedef VkPhysicalDeviceDriverProperties VkPhysicalDeviceDriverPropertiesKHR; 
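// The declarations above mirror the upstream VK_KHR_driver_properties extension (same enum values
// and struct layout as the official Vulkan headers), so this header-fix path should keep compiling
// against older Vulkan SDKs that predate the extension. The driverID filled in through
// VkPhysicalDeviceDriverPropertiesKHR is what create_gpu_instance() checks earlier in this patch to
// disable cooperative matrix on Mesa RADV and the Intel open-source Mesa driver
// (https://gitlab.freedesktop.org/mesa/mesa/-/issues/10847).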
+#endif // VK_KHR_driver_properties + #endif // NCNN_VULKAN_HEADER_FIX_H From 73d35193265edf9f01ff24ec5a898ed9faaec324 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 18 Oct 2024 10:23:10 +0800 Subject: [PATCH 07/15] layernorm x86 optimization, re (#5745) --- src/layer/x86/layernorm_x86.cpp | 735 ++++++++++++++++---------------- 1 file changed, 358 insertions(+), 377 deletions(-) diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index 21840c6b3d20..91e36163ff32 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -13,9 +13,6 @@ // specific language governing permissions and limitations under the License. #include "layernorm_x86.h" -#include "x86_usability.h" - -#include #if __SSE2__ #include @@ -24,6 +21,8 @@ #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" + namespace ncnn { LayerNorm_x86::LayerNorm_x86() @@ -33,37 +32,53 @@ LayerNorm_x86::LayerNorm_x86() #endif // __SSE2__ } -static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, int elemcount, int size) +static void layernorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) { - int i = 0; + const int size = elemcount * elempack; + #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _sum_512 = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _sum_512 = _mm512_add_ps(_sum_512, _cur); - } + __m512 _mean_avx512 = _mm512_set1_ps(0.f); #endif // __AVX512F__ - __m256 _sum_256 = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _sum_256 = _mm256_add_ps(_sum_256, _cur); - } + __m256 _mean_avx = _mm256_set1_ps(0.f); #endif // __AVX__ - __m128 _sum_128 = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _sum_128 = _mm_add_ps(_sum_128, _cur); - } + __m128 _mean = _mm_set1_ps(0.f); #endif // __SSE2__ - float sum = 0.0f; - for (; i < size; ++i, ++ptr) + float mean = 0.f; { - sum += *ptr; + const float* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr0); + _mean_avx512 = _mm512_add_ps(_mean_avx512, _p); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr0); + _mean_avx = _mm256_add_ps(_mean_avx, _p); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr0); + _mean = _mm_add_ps(_mean, _p); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + mean += ptr0[0]; + ptr0++; + } } #if __SSE2__ @@ -71,110 +86,128 @@ static NCNN_FORCEINLINE void fast_mean(float* ptr, float* mean, int elempack, in #if __AVX512F__ if (elempack == 16) { - __m512 _mean = _mm512_div_ps(_sum_512, _mm512_set1_ps((float)elemcount)); - _mm512_storeu_ps(mean, _mean); + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + _mean_avx512 = _mm512_div_ps(_mean_avx512, _elemcount); } #endif // __AVX512F__ - if (elempack == 8) { #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); - _sum_256 = _mm256_add_ps(_sum_256, _high); - _sum_256 = _mm256_add_ps(_sum_256, _low); + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + 
_mean_avx = _mm256_add_ps(_mean_avx, _mean1); } #endif // __AVX512F__ - __m256 _mean = _mm256_div_ps(_sum_256, _mm256_set1_ps((float)elemcount)); - _mm256_storeu_ps(mean, _mean); + + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + _mean_avx = _mm256_div_ps(_mean_avx, _elemcount); +#if __AVX512F__ + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ } #endif // __AVX__ - if (elempack == 4) { #if __AVX__ #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sum_512), 1)); - _sum_256 = _mm256_add_ps(_sum_256, _high); - _sum_256 = _mm256_add_ps(_sum_256, _low); + __m256 _mean0 = _mm512_castps512_ps256(_mean_avx512); + __m256 _mean1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_mean_avx512), 1)); + _mean_avx = _mm256_add_ps(_mean_avx, _mean0); + _mean_avx = _mm256_add_ps(_mean_avx, _mean1); } #endif // __AVX512F__ { - __m128 _low = _mm256_castps256_ps128(_sum_256); - __m128 _high = _mm256_extractf128_ps(_sum_256, 1); - _sum_128 = _mm_add_ps(_sum_128, _low); - _sum_128 = _mm_add_ps(_sum_128, _high); + __m128 _mean0 = _mm256_castps256_ps128(_mean_avx); + __m128 _mean1 = _mm256_extractf128_ps(_mean_avx, 1); + _mean = _mm_add_ps(_mean, _mean0); + _mean = _mm_add_ps(_mean, _mean1); } #endif // __AVX__ - __m128 _mean = _mm_div_ps(_sum_128, _mm_set1_ps((float)elemcount)); - _mm_storeu_ps(mean, _mean); + + __m128 _elemcount = _mm_set1_ps((float)elemcount); + _mean = _mm_div_ps(_mean, _elemcount); +#if __AVX__ + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); +#if __AVX512F__ + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ +#endif // __AVX__ } #endif // __SSE2__ - if (elempack == 1) { #if __SSE2__ #if __AVX__ #if __AVX512F__ - sum += _mm512_comp_reduce_add_ps(_sum_512); + mean += _mm512_comp_reduce_add_ps(_mean_avx512); #endif // __AVX512F__ - sum += _mm256_reduce_add_ps(_sum_256); + mean += _mm256_reduce_add_ps(_mean_avx); #endif // __AVX__ - sum += _mm_reduce_add_ps(_sum_128); + mean += _mm_reduce_add_ps(_mean); #endif // __SSE2__ - mean[0] = sum / elemcount; - } -} -static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, const float* mean, int elempack, int elemcount, int size) -{ - const float _mean = mean[0]; + mean = mean / elemcount; #if __SSE2__ - __m128 _mean_128 = (elempack == 4) ? _mm_loadu_ps(mean) : _mm_set1_ps(_mean); + _mean = _mm_set1_ps(mean); #if __AVX__ - __m256 _mean_256 = (elempack == 8) ? _mm256_loadu_ps(mean) : _mm256_insertf128_ps(_mm256_castps128_ps256(_mean_128), _mean_128, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - __m512 _mean_512 = (elempack == 16) ? 
_mm512_loadu_ps(mean) : _mm512_insertf32x8(_mm512_castps256_ps512(_mean_256), _mean_256, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ + } - int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _sq_sum_512 = _mm512_setzero_ps(); - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_sub_ps(_cur, _mean_512); - _sq_sum_512 = _mm512_fmadd_ps(_cur, _cur, _sq_sum_512); - } + __m512 _var_avx512 = _mm512_set1_ps(0.f); #endif // __AVX512F__ - __m256 _sq_sum_256 = _mm256_setzero_ps(); - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_sub_ps(_cur, _mean_256); - _sq_sum_256 = _mm256_comp_fmadd_ps(_cur, _cur, _sq_sum_256); - } + __m256 _var_avx = _mm256_set1_ps(0.f); #endif // __AVX__ - __m128 _sq_sum_128 = _mm_setzero_ps(); - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_sub_ps(_cur, _mean_128); - _sq_sum_128 = _mm_comp_fmadd_ps(_cur, _cur, _sq_sum_128); - } + __m128 _var = _mm_set1_ps(0.f); #endif // __SSE2__ - float sq_sum = 0.0f; - for (; i < size; ++i, ++ptr) + float var = 0.f; { - float tmp = *ptr - _mean; - sq_sum += tmp * tmp; + const float* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr0); + _p = _mm512_sub_ps(_p, _mean_avx512); + _var_avx512 = _mm512_fmadd_ps(_p, _p, _var_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr0); + _p = _mm256_sub_ps(_p, _mean_avx); + _var_avx = _mm256_comp_fmadd_ps(_p, _p, _var_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr0); + _p = _mm_sub_ps(_p, _mean); + _var = _mm_comp_fmadd_ps(_p, _p, _var); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + float v = ptr0[0] - mean; + var += v * v; + ptr0++; + } } #if __SSE2__ @@ -182,384 +215,332 @@ static NCNN_FORCEINLINE void fast_var(float* ptr, float* var, const float* mean, #if __AVX512F__ if (elempack == 16) { - __m512 _var = _mm512_div_ps(_sq_sum_512, _mm512_set1_ps((float)elemcount)); - _mm512_storeu_ps(var, _var); + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + __m512 _eps = _mm512_set1_ps(eps); + _var_avx512 = _mm512_div_ps(_var_avx512, _elemcount); + _var_avx512 = _mm512_add_ps(_var_avx512, _eps); + __m256 _var0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 0)); + __m256 _var1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_var_avx512, 1)); + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var0), _var1, 1); + _mean_avx512 = _mm512_mul_ps(_mean_avx512, _var_avx512); } #endif // __AVX512F__ - if (elempack == 8) { #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sq_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); } #endif // __AVX512F__ - __m256 _var = _mm256_div_ps(_sq_sum_256, _mm256_set1_ps((float)elemcount)); - _mm256_storeu_ps(var, _var); + + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + __m256 _eps 
= _mm256_set1_ps(eps); + _var_avx = _mm256_div_ps(_var_avx, _elemcount); + _var_avx = _mm256_add_ps(_var_avx, _eps); + _var_avx = _mm256_rsqrt_ps(_var_avx); + _mean_avx = _mm256_mul_ps(_mean_avx, _var_avx); +#if __AVX512F__ + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); +#endif // __AVX512F__ } #endif // __AVX__ - if (elempack == 4) { #if __AVX__ #if __AVX512F__ { - __m256 _low = _mm512_castps512_ps256(_sq_sum_512); - __m256 _high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_sq_sum_512), 1)); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _high); - _sq_sum_256 = _mm256_add_ps(_sq_sum_256, _low); + __m256 _var0 = _mm512_castps512_ps256(_var_avx512); + __m256 _var1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_var_avx512), 1)); + _var_avx = _mm256_add_ps(_var_avx, _var0); + _var_avx = _mm256_add_ps(_var_avx, _var1); } #endif // __AVX512F__ { - __m128 _low = _mm256_castps256_ps128(_sq_sum_256); - __m128 _high = _mm256_extractf128_ps(_sq_sum_256, 1); - _sq_sum_128 = _mm_add_ps(_sq_sum_128, _low); - _sq_sum_128 = _mm_add_ps(_sq_sum_128, _high); + __m128 _var0 = _mm256_castps256_ps128(_var_avx); + __m128 _var1 = _mm256_extractf128_ps(_var_avx, 1); + _var = _mm_add_ps(_var, _var0); + _var = _mm_add_ps(_var, _var1); } #endif // __AVX__ - __m128 _var = _mm_div_ps(_sq_sum_128, _mm_set1_ps((float)elemcount)); - _mm_storeu_ps(var, _var); - } -#endif // __SSE2__ - if (elempack == 1) - { -#if __SSE2__ + __m128 _elemcount = _mm_set1_ps((float)elemcount); + __m128 _eps = _mm_set1_ps(eps); + _var = _mm_div_ps(_var, _elemcount); + _var = _mm_add_ps(_var, _eps); + _var = _mm_rsqrt_ps(_var); + _mean = _mm_mul_ps(_mean, _var); #if __AVX__ + _var_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_var), _var, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - sq_sum += _mm512_comp_reduce_add_ps(_sq_sum_512); + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ - sq_sum += _mm256_reduce_add_ps(_sq_sum_256); #endif // __AVX__ - sq_sum += _mm_reduce_add_ps(_sq_sum_128); -#endif // __SSE2__ - var[0] = sq_sum / elemcount; } -} - -static NCNN_FORCEINLINE void fast_fmadd(float* ptr, const float* a, const float* b, int elempack, int size) -{ - const float _a = a[0]; - const float _b = b[0]; +#endif // __SSE2__ + if (elempack == 1) + { #if __SSE2__ - __m128 _a_128 = (elempack == 4) ? _mm_loadu_ps(a) : _mm_set1_ps(_a); - __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps(b) : _mm_set1_ps(_b); #if __AVX__ - __m256 _a_256 = (elempack == 8) ? _mm256_loadu_ps(a) : _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps(b) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = (elempack == 16) ? _mm512_loadu_ps(a) : _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = (elempack == 16) ? 
_mm512_loadu_ps(b) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); + var += _mm512_comp_reduce_add_ps(_var_avx512); #endif // __AVX512F__ + var += _mm256_reduce_add_ps(_var_avx); #endif // __AVX__ + var += _mm_reduce_add_ps(_var); #endif // __SSE2__ - int i = 0; + var = 1.f / sqrtf(var / elemcount + eps); + mean = mean * var; #if __SSE2__ + _var = _mm_set1_ps(var); + _mean = _mm_set1_ps(mean); #if __AVX__ + _var_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_var), _var, 1); + _mean_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_mean), _mean, 1); #if __AVX512F__ - for (; i + 16 <= size; i += 16, ptr += 16) - { - __m512 _cur = _mm512_loadu_ps(ptr); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _mm512_storeu_ps(ptr, _cur); - } + _var_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_var_avx), _var_avx, 1); + _mean_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_mean_avx), _mean_avx, 1); #endif // __AVX512F__ - for (; i + 8 <= size; i += 8, ptr += 8) - { - __m256 _cur = _mm256_loadu_ps(ptr); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _mm256_storeu_ps(ptr, _cur); - } #endif // __AVX__ - for (; i + 4 <= size; i += 4, ptr += 4) - { - __m128 _cur = _mm_loadu_ps(ptr); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _mm_storeu_ps(ptr, _cur); - } #endif // __SSE2__ - for (; i < size; ++i, ++ptr) - { - *ptr = (*ptr) * _a + _b; } -} -static NCNN_FORCEINLINE void fast_fmadd_fmadd(float* ptr, const float* a, const float* b, const float* gamma, const float* beta, int elempack, int size) -{ + if (gamma_ptr && beta_ptr) + { + int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - if (elempack == 16) - { - int i = 0; - __m512 _a_512 = _mm512_loadu_ps(a); - __m512 _b_512 = _mm512_loadu_ps(b); - for (; i + 16 <= size; i += 16, ptr += 16, ++gamma, ++beta) + if (elempack == 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_set1_ps(*gamma); - __m512 _beta = _mm512_set1_ps(*beta); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]); + __m512 _beta = _mm512_set1_ps(beta_ptr[0]); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 1; + beta_ptr += 1; + } } - } #endif // __AVX512F__ - - if (elempack == 8) - { - int i = 0; - __m256 _a_256 = _mm256_loadu_ps(a); - __m256 _b_256 = _mm256_loadu_ps(b); -#if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 2, beta += 2) + if (elempack == 8) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - _gamma_0 = _mm512_mask_blend_ps(0xFF00, _gamma_0, _gamma_1); - _beta_0 = _mm512_mask_blend_ps(0xFF00, _beta_0, _beta_1); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]); + __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]); + __m512 _gamma = 
_mm512_insertf32x8(_mm512_castps256_ps512(_gamma0), _gamma1, 1); + __m256 _beta0 = _mm256_set1_ps(beta_ptr[0]); + __m256 _beta1 = _mm256_set1_ps(beta_ptr[1]); + __m512 _beta = _mm512_insertf32x8(_mm512_castps256_ps512(_beta0), _beta1, 1); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 2; + beta_ptr += 2; + } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, ++gamma, ++beta) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_set1_ps(*gamma); - __m256 _beta = _mm256_set1_ps(*beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]); + __m256 _beta = _mm256_set1_ps(beta_ptr[0]); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 1; + beta_ptr += 1; + } } - } #endif // __AVX__ - - if (elempack == 4) - { - int i = 0; - __m128 _a_128 = _mm_loadu_ps(a); - __m128 _b_128 = _mm_loadu_ps(b); + if (elempack == 4) + { #if __AVX__ - __m256 _a_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 4, beta += 4) - { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma_0 = _mm512_set1_ps(gamma[0]); - __m512 _gamma_1 = _mm512_set1_ps(gamma[1]); - __m512 _gamma_2 = _mm512_set1_ps(gamma[2]); - __m512 _gamma_3 = _mm512_set1_ps(gamma[3]); - __m512 _beta_0 = _mm512_set1_ps(beta[0]); - __m512 _beta_1 = _mm512_set1_ps(beta[1]); - __m512 _beta_2 = _mm512_set1_ps(beta[2]); - __m512 _beta_3 = _mm512_set1_ps(beta[3]); - _gamma_0 = _mm512_mask_blend_ps(0x00F0, _gamma_0, _gamma_1); - _gamma_0 = _mm512_mask_blend_ps(0x0F00, _gamma_0, _gamma_2); - _gamma_0 = _mm512_mask_blend_ps(0xF000, _gamma_0, _gamma_3); - _beta_0 = _mm512_mask_blend_ps(0x00F0, _beta_0, _beta_1); - _beta_0 = _mm512_mask_blend_ps(0x0F00, _beta_0, _beta_2); - _beta_0 = _mm512_mask_blend_ps(0xF000, _beta_0, _beta_3); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm512_storeu_ps(ptr, _cur); - } + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]); + __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]); + __m256 _gamma01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + __m256 _gamma23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma2), _gamma3, 1); + __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma01), _gamma23, 1); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m128 _beta2 = _mm_set1_ps(beta_ptr[2]); + __m128 _beta3 = _mm_set1_ps(beta_ptr[3]); + __m256 _beta01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta0), _beta1, 1); + __m256 _beta23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta2), _beta3, 1); + __m512 _beta = _mm512_insertf32x8(_mm512_castps256_ps512(_beta01), _beta23, 1); + _p = 
_mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 4; + beta_ptr += 4; + } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 2, beta += 2) - { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma_0 = _mm256_set1_ps(gamma[0]); - __m256 _gamma_1 = _mm256_set1_ps(gamma[1]); - __m256 _beta_0 = _mm256_set1_ps(beta[0]); - __m256 _beta_1 = _mm256_set1_ps(beta[1]); - _gamma_0 = _mm256_blend_ps(_gamma_0, _gamma_1, 0xF0); - _beta_0 = _mm256_blend_ps(_beta_0, _beta_1, 0xF0); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma_0, _beta_0); - _mm256_storeu_ps(ptr, _cur); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m256 _gamma = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + __m128 _beta0 = _mm_set1_ps(beta_ptr[0]); + __m128 _beta1 = _mm_set1_ps(beta_ptr[1]); + __m256 _beta = _mm256_insertf128_ps(_mm256_castps128_ps256(_beta0), _beta1, 1); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 2; + beta_ptr += 2; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(gamma_ptr[0]); + __m128 _beta = _mm_set1_ps(beta_ptr[0]); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } } + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma_ptr); + __m512 _beta = _mm512_loadu_ps(beta_ptr); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _p = _mm512_fmadd_ps(_p, _gamma, _beta); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 16; + beta_ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma_ptr); + __m256 _beta = _mm256_loadu_ps(beta_ptr); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _p = _mm256_comp_fmadd_ps(_p, _gamma, _beta); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 8; + beta_ptr += 8; + } #endif // __AVX__ - - for (; i + 4 <= size; i += 4, ptr += 4, ++gamma, ++beta) + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma_ptr); + __m128 _beta = _mm_loadu_ps(beta_ptr); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _p = _mm_comp_fmadd_ps(_p, _gamma, _beta); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_set1_ps(*gamma); - __m128 _beta = _mm_set1_ps(*beta); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + ptr[0] = (ptr[0] * var - mean) * gamma_ptr[0] + beta_ptr[0]; + ptr++; + gamma_ptr++; + beta_ptr++; } } -#endif // __SSE2__ - - if (elempack == 1) + else { int i = 0; - const float _a = a[0]; - const float _b = b[0]; #if __SSE2__ - __m128 _a_128 = _mm_set1_ps(_a); - __m128 _b_128 = _mm_set1_ps(_b); #if __AVX__ - __m256 _a_256 = _mm256_insertf128_ps(_mm256_castps128_ps256(_a_128), _a_128, 1); - __m256 _b_256 = 
_mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); #if __AVX512F__ - __m512 _a_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_a_256), _a_256, 1); - __m512 _b_512 = _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); - for (; i + 16 <= size; i += 16, ptr += 16, gamma += 16, beta += 16) + for (; i + 15 < size; i += 16) { - __m512 _cur = _mm512_loadu_ps(ptr); - __m512 _gamma = _mm512_loadu_ps(gamma); - __m512 _beta = _mm512_loadu_ps(beta); - _cur = _mm512_fmadd_ps(_cur, _a_512, _b_512); - _cur = _mm512_fmadd_ps(_cur, _gamma, _beta); - _mm512_storeu_ps(ptr, _cur); + __m512 _p = _mm512_loadu_ps(ptr); + _p = _mm512_fmsub_ps(_p, _var_avx512, _mean_avx512); + _mm512_storeu_ps(ptr, _p); + ptr += 16; } #endif // __AVX512F__ - - for (; i + 8 <= size; i += 8, ptr += 8, gamma += 8, beta += 8) + for (; i + 7 < size; i += 8) { - __m256 _cur = _mm256_loadu_ps(ptr); - __m256 _gamma = _mm256_loadu_ps(gamma); - __m256 _beta = _mm256_loadu_ps(beta); - _cur = _mm256_comp_fmadd_ps(_cur, _a_256, _b_256); - _cur = _mm256_comp_fmadd_ps(_cur, _gamma, _beta); - _mm256_storeu_ps(ptr, _cur); + __m256 _p = _mm256_loadu_ps(ptr); + _p = _mm256_comp_fmsub_ps(_p, _var_avx, _mean_avx); + _mm256_storeu_ps(ptr, _p); + ptr += 8; } #endif // __AVX__ - - for (; i + 4 <= size; i += 4, ptr += 4, gamma += 4, beta += 4) + for (; i + 3 < size; i += 4) { - __m128 _cur = _mm_loadu_ps(ptr); - __m128 _gamma = _mm_loadu_ps(gamma); - __m128 _beta = _mm_loadu_ps(beta); - _cur = _mm_comp_fmadd_ps(_cur, _a_128, _b_128); - _cur = _mm_comp_fmadd_ps(_cur, _gamma, _beta); - _mm_storeu_ps(ptr, _cur); + __m128 _p = _mm_loadu_ps(ptr); + _p = _mm_comp_fmsub_ps(_p, _var, _mean); + _mm_storeu_ps(ptr, _p); + ptr += 4; } #endif // __SSE2__ - - for (; i < size; ++i, ++ptr, ++gamma, ++beta) + for (; i < size; i++) { - *ptr = ((*ptr) * _a + _b) * (*gamma) + (*beta); + ptr[0] = ptr[0] * var - mean; + ptr++; } } } -static NCNN_FORCEINLINE void fast_1d_layer_norm(float* ptr, int elempack, int elemcount, int size, const float* gamma, const float* beta, int affine, float eps) -{ - float mean[16] = {0.f}, var[16] = {0.f}; - fast_mean(ptr, mean, elempack, elemcount, size); - fast_var(ptr, var, mean, elempack, elemcount, size); - float *a = var, *b = mean; - -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _a = _mm512_set1_ps(1.0f); - __m512 _eps = _mm512_set1_ps(eps); - __m512 _b = _mm512_setzero_ps(); - __m512 _var = _mm512_loadu_ps(var); - _var = _mm512_add_ps(_var, _eps); - __m512 _sqrt_var = _mm512_sqrt_ps(_var); - _a = _mm512_div_ps(_a, _sqrt_var); - __m512 _mean = _mm512_loadu_ps(mean); - _b = _mm512_fnmadd_ps(_mean, _a, _b); - - _mm512_storeu_ps(a, _a); - _mm512_storeu_ps(b, _b); - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _a = _mm256_set1_ps(1.0f); - __m256 _eps = _mm256_set1_ps(eps); - __m256 _b = _mm256_setzero_ps(); - __m256 _var = _mm256_loadu_ps(var); - _var = _mm256_add_ps(_var, _eps); - __m256 _sqrt_var = _mm256_sqrt_ps(_var); - _a = _mm256_div_ps(_a, _sqrt_var); - __m256 _mean = _mm256_loadu_ps(mean); - _b = _mm256_comp_fnmadd_ps(_mean, _a, _b); - - _mm256_storeu_ps(a, _a); - _mm256_storeu_ps(b, _b); - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _a = _mm_set1_ps(1.0f); - __m128 _eps = _mm_set1_ps(eps); - __m128 _b = _mm_setzero_ps(); - __m128 _var = _mm_loadu_ps(var); - _var = _mm_add_ps(_var, _eps); - __m128 _sqrt_var = _mm_sqrt_ps(_var); - _a = _mm_div_ps(_a, _sqrt_var); - __m128 _mean = _mm_loadu_ps(mean); - _b = _mm_comp_fnmadd_ps(_mean, _a, _b); - - 
_mm_storeu_ps(a, _a); - _mm_storeu_ps(b, _b); - } -#endif // __SSE2__ - if (elempack == 1) - { - a[0] = 1.0f / sqrtf(var[0] + eps); - b[0] = -mean[0] * (a[0]); - } - - if (affine) - { - fast_fmadd_fmadd(ptr, a, b, gamma, beta, elempack, size); - } - else - { - fast_fmadd(ptr, a, b, elempack, size); - } -} - int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - int dims = bottom_top_blob.dims; - int elempack = bottom_top_blob.elempack; - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - int channels = bottom_top_blob.c; - - const float* gamma = gamma_data; - const float* beta = beta_data; + const int dims = bottom_top_blob.dims; + const int elempack = bottom_top_blob.elempack; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; if (dims == 1) { - int elemcount = w * elempack; + // assert affine_size == w + float* ptr = bottom_top_blob; - // 1D layer norm is special. Treat them as unpacked. - fast_1d_layer_norm(ptr, 1, elemcount, elemcount, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w * elempack, 1); } if (dims == 2) { + // assert affine_size == w + #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < h; ++i) + for (int i = 0; i < h; i++) { float* ptr = bottom_top_blob.row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); } } @@ -568,22 +549,22 @@ int LayerNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons if (affine_size == w) { #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; ++q) + for (int q = 0; q < channels; q++) { - for (int i = 0; i < h; ++i) + for (int i = 0; i < h; i++) { float* ptr = bottom_top_blob.channel(q).row(i); - fast_1d_layer_norm(ptr, elempack, w, w * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); } } } else // if (affine_size == w * h) { #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; ++q) + for (int q = 0; q < channels; q++) { float* ptr = bottom_top_blob.channel(q); - fast_1d_layer_norm(ptr, elempack, w * h, w * h * elempack, gamma, beta, affine, eps); + layernorm(ptr, gamma_data, beta_data, eps, w * h, elempack); } } } From cbd17cd062235ec9533fcc42aeb4a683d8ce32ee Mon Sep 17 00:00:00 2001 From: Upliner Mikhalych Date: Fri, 18 Oct 2024 05:23:57 +0300 Subject: [PATCH 08/15] Fix #5741 don't crash when vkCreateDevice fails (#5742) --- src/gpu.cpp | 45 ++++++++++++++++++++++++++++++++++++--------- src/gpu.h | 1 + src/net.cpp | 4 ++-- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/gpu.cpp b/src/gpu.cpp index cdcec0f7f02e..c8b48901ef1e 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -2259,10 +2259,7 @@ class VkDummyCompute : public VkCompute class VulkanDevicePrivate { public: - VulkanDevicePrivate(VulkanDevice* _vkdev) - : vkdev(_vkdev) - { - } + VulkanDevicePrivate(VulkanDevice* _vkdev); VulkanDevice* const vkdev; // dummy buffer and image @@ -2317,8 +2314,22 @@ class VulkanDevicePrivate // to pack1 | pack4 | pack8 mutable ncnn::Packing_vulkan* uop_packing[2][2][3][3][3]; mutable Mutex uop_lock; + + // device is valid and sucessfully initialized + bool valid; }; +VulkanDevicePrivate::VulkanDevicePrivate(VulkanDevice* _vkdev) + : vkdev(_vkdev) +{ + device = 0; + texelfetch_sampler = 0; + dummy_allocator = 0; + pipeline_cache = 0; + valid = false; + memset(uop_packing, 0, 
sizeof(uop_packing)); +} + int VulkanDevicePrivate::create_dummy_buffer_image() { dummy_allocator = new VkDummyAllocator(vkdev); @@ -2357,7 +2368,11 @@ void VulkanDevicePrivate::destroy_dummy_buffer_image() dummy_image_readonly.release(); #endif - delete dummy_allocator; + if (dummy_allocator) + { + delete dummy_allocator; + dummy_allocator = 0; + } } const ncnn::Packing_vulkan* VulkanDevicePrivate::get_utility_operator(int storage_type_from, int storage_type_to, int cast_type_from_index, int cast_type_to_index, int packing_type_to_index) const @@ -2702,6 +2717,7 @@ VulkanDevice::VulkanDevice(int device_index) if (ret != VK_SUCCESS) { NCNN_LOGE("vkCreateDevice failed %d", ret); + return; } init_device_extension(); @@ -2761,7 +2777,6 @@ VulkanDevice::VulkanDevice(int device_index) samplerCreateInfo.borderColor = VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK; samplerCreateInfo.unnormalizedCoordinates = VK_TRUE; - d->texelfetch_sampler = 0; ret = vkCreateSampler(d->device, &samplerCreateInfo, 0, &d->texelfetch_sampler); if (ret != VK_SUCCESS) { @@ -2773,11 +2788,12 @@ VulkanDevice::VulkanDevice(int device_index) if (cret != 0) { NCNN_LOGE("VulkanDevice create_dummy_buffer_image failed %d", cret); + return; } d->pipeline_cache = new PipelineCache(this); - memset(d->uop_packing, 0, sizeof(d->uop_packing)); + d->valid = true; } VulkanDevice::~VulkanDevice() @@ -2802,9 +2818,15 @@ VulkanDevice::~VulkanDevice() } d->staging_allocators.clear(); - delete d->pipeline_cache; + if (d->pipeline_cache) + { + delete d->pipeline_cache; + } - vkDestroyDevice(d->device, 0); + if (d->device) + { + vkDestroyDevice(d->device, 0); + } delete d; } @@ -2824,6 +2846,11 @@ VkDevice VulkanDevice::vkdevice() const return d->device; } +bool VulkanDevice::is_valid() const +{ + return d->valid; +} + VkShaderModule VulkanDevice::compile_shader_module(const uint32_t* spv_data, size_t spv_data_size) const { VkShaderModuleCreateInfo shaderModuleCreateInfo; diff --git a/src/gpu.h b/src/gpu.h index b98827b69ebd..0df13e7ea5ed 100644 --- a/src/gpu.h +++ b/src/gpu.h @@ -341,6 +341,7 @@ class NCNN_EXPORT VulkanDevice const GpuInfo& info; VkDevice vkdevice() const; + bool is_valid() const; VkShaderModule compile_shader_module(const uint32_t* spv_data, size_t spv_data_size) const; diff --git a/src/net.cpp b/src/net.cpp index 3574944e726e..32b5b2abd601 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1381,7 +1381,7 @@ int Net::load_param(const DataReader& dr) if (opt.use_vulkan_compute) { if (!d->vkdev) d->vkdev = get_gpu_device(); - if (!d->vkdev) opt.use_vulkan_compute = false; // no vulkan device, fallback to cpu + if (!d->vkdev || !d->vkdev->is_valid()) opt.use_vulkan_compute = false; // no valid vulkan device, fallback to cpu } if (opt.use_vulkan_compute) { @@ -1677,7 +1677,7 @@ int Net::load_param_bin(const DataReader& dr) if (opt.use_vulkan_compute) { if (!d->vkdev) d->vkdev = get_gpu_device(); - if (!d->vkdev) opt.use_vulkan_compute = false; // no vulkan device, fallback to cpu + if (!d->vkdev || !d->vkdev->is_valid()) opt.use_vulkan_compute = false; // no valid vulkan device, fallback to cpu } if (opt.use_vulkan_compute) { From 8fe62812c9154d07ecd48b529a986f3c1488d4e4 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 18 Oct 2024 21:32:18 +0800 Subject: [PATCH 09/15] arm neon optimization for layernorm fp32/bf16s/fp16s (#5746) --- src/layer/arm/layernorm_arm.cpp | 517 ++++++++++++++++++++++++ src/layer/arm/layernorm_arm.h | 40 ++ src/layer/arm/layernorm_arm_asimdhp.cpp | 351 ++++++++++++++++ src/net.cpp | 2 +- tests/testutil.cpp | 2 +- 
5 files changed, 910 insertions(+), 2 deletions(-) create mode 100644 src/layer/arm/layernorm_arm.cpp create mode 100644 src/layer/arm/layernorm_arm.h create mode 100644 src/layer/arm/layernorm_arm_asimdhp.cpp diff --git a/src/layer/arm/layernorm_arm.cpp b/src/layer/arm/layernorm_arm.cpp new file mode 100644 index 000000000000..4c49a5e76b7e --- /dev/null +++ b/src/layer/arm/layernorm_arm.cpp @@ -0,0 +1,517 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "layernorm_arm.h" + +#if __ARM_NEON +#include +#include "neon_mathfun.h" +#endif // __ARM_NEON + +#include "arm_usability.h" +#include "cpu.h" + +namespace ncnn { + +LayerNorm_arm::LayerNorm_arm() +{ +#if __ARM_NEON + support_packing = true; +#if NCNN_ARM82 + support_fp16_storage = cpu_support_arm_asimdhp(); +#endif +#endif // __ARM_NEON + +#if NCNN_BF16 + support_bf16_storage = true; +#endif +} + +static void layernorm(float* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + +#if __ARM_NEON + float32x4_t _mean = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float mean = 0.f; + { + const float* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr0); + _mean = vaddq_f32(_mean, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + mean += ptr0[0]; + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + _mean = div_ps(_mean, _elemcount); + } +#endif // __ARM_NEON + if (elempack == 1) + { +#if __ARM_NEON +#if __aarch64__ + mean += vaddvq_f32(_mean); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_mean), vget_high_f32(_mean)); + _s2 = vpadd_f32(_s2, _s2); + mean += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + mean = mean / elemcount; +#if __ARM_NEON + _mean = vdupq_n_f32(mean); +#endif // __ARM_NEON + } + +#if __ARM_NEON + float32x4_t _var = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float var = 0.f; + { + const float* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr0); + _p = vsubq_f32(_p, _mean); + _var = vmlaq_f32(_var, _p, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = ptr0[0] - mean; + var += v * v; + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + _var = div_ps(_var, _elemcount); + _var = vaddq_f32(_var, _eps); + float32x4_t _rsqrt_var = vrsqrteq_f32(_var); + _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _mean = vmulq_f32(_mean, _var); + _mean = vnegq_f32(_mean); + } +#endif // __ARM_NEON + if (elempack 
== 1) + { +#if __ARM_NEON +#if __aarch64__ + var += vaddvq_f32(_var); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_var), vget_high_f32(_var)); + _s2 = vpadd_f32(_s2, _s2); + var += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + var = 1.f / sqrtf(var / elemcount + eps); + mean = -mean * var; +#if __ARM_NEON + _var = vdupq_n_f32(var); + _mean = vdupq_n_f32(mean); +#endif // __ARM_NEON + } + + if (gamma_ptr && beta_ptr) + { + int i = 0; +#if __ARM_NEON + if (elempack == 4) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _beta = vdupq_n_f32(beta_ptr[0]); + _p = vmlaq_f32(_mean, _p, _var); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1q_f32(ptr, _p); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + float32x4_t _beta = vld1q_f32(beta_ptr); + _p = vmlaq_f32(_mean, _p, _var); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1q_f32(ptr, _p); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } +#endif // __ARM_NEON + for (; i < size; i++) + { + ptr[0] = (ptr[0] * var + mean) * gamma_ptr[0] + beta_ptr[0]; + ptr++; + gamma_ptr++; + beta_ptr++; + } + } + else + { + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + _p = vmlaq_f32(_mean, _p, _var); + vst1q_f32(ptr, _p); + ptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + ptr[0] = ptr[0] * var + mean; + ptr++; + } + } +} + +int LayerNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int elembits = bottom_top_blob.elembits(); + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + return forward_inplace_fp16s(bottom_top_blob, opt); +#endif + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + float* ptr = bottom_top_blob; + layernorm(ptr, gamma_data, beta_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.row(i); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.channel(q).row(i); + layernorm(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + layernorm(ptr, gamma_data, beta_data, eps, w * h, elempack); + } + } + } + + return 0; +} + +#if NCNN_BF16 +static void layernorm_bf16s(unsigned short* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + +#if __ARM_NEON + float32x4_t _mean = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float mean = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; 
i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr0)); + _mean = vaddq_f32(_mean, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + mean += bfloat16_to_float32(ptr0[0]); + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + _mean = div_ps(_mean, _elemcount); + } +#endif // __ARM_NEON + if (elempack == 1) + { +#if __ARM_NEON +#if __aarch64__ + mean += vaddvq_f32(_mean); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_mean), vget_high_f32(_mean)); + _s2 = vpadd_f32(_s2, _s2); + mean += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + mean = mean / elemcount; +#if __ARM_NEON + _mean = vdupq_n_f32(mean); +#endif // __ARM_NEON + } + +#if __ARM_NEON + float32x4_t _var = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float var = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr0)); + _p = vsubq_f32(_p, _mean); + _var = vmlaq_f32(_var, _p, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr0[0]) - mean; + var += v * v; + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + _var = div_ps(_var, _elemcount); + _var = vaddq_f32(_var, _eps); + float32x4_t _rsqrt_var = vrsqrteq_f32(_var); + _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _mean = vmulq_f32(_mean, _var); + _mean = vnegq_f32(_mean); + } +#endif // __ARM_NEON + if (elempack == 1) + { +#if __ARM_NEON +#if __aarch64__ + var += vaddvq_f32(_var); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_var), vget_high_f32(_var)); + _s2 = vpadd_f32(_s2, _s2); + var += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + var = 1.f / sqrtf(var / elemcount + eps); + mean = -mean * var; +#if __ARM_NEON + _var = vdupq_n_f32(var); + _mean = vdupq_n_f32(mean); +#endif // __ARM_NEON + } + + if (gamma_ptr && beta_ptr) + { + int i = 0; +#if __ARM_NEON + if (elempack == 4) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _beta = vdupq_n_f32(beta_ptr[0]); + _p = vmlaq_f32(_mean, _p, _var); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + float32x4_t _beta = vld1q_f32(beta_ptr); + _p = vmlaq_f32(_mean, _p, _var); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr[0]); + ptr[0] = float32_to_bfloat16((v * var + mean) * gamma_ptr[0] + beta_ptr[0]); + ptr++; + gamma_ptr++; + beta_ptr++; + } + } + else + { + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + _p = vmlaq_f32(_mean, _p, _var); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr[0]); + ptr[0] = float32_to_bfloat16(v * var + mean); + ptr++; + } + } +} + +int 
LayerNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + unsigned short* ptr = bottom_top_blob; + layernorm_bf16s(ptr, gamma_data, beta_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + layernorm_bf16s(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.channel(q).row(i); + layernorm_bf16s(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + layernorm_bf16s(ptr, gamma_data, beta_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/arm/layernorm_arm.h b/src/layer/arm/layernorm_arm.h new file mode 100644 index 000000000000..d3bcac1b2762 --- /dev/null +++ b/src/layer/arm/layernorm_arm.h @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_LAYERNORM_ARM_H +#define LAYER_LAYERNORM_ARM_H + +#include "layernorm.h" + +namespace ncnn { + +class LayerNorm_arm : public LayerNorm +{ +public: + LayerNorm_arm(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_ARM82 + int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const; +#endif +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif +}; + +} // namespace ncnn + +#endif // LAYER_LAYERNORM_ARM_H diff --git a/src/layer/arm/layernorm_arm_asimdhp.cpp b/src/layer/arm/layernorm_arm_asimdhp.cpp new file mode 100644 index 000000000000..1b746707dc8e --- /dev/null +++ b/src/layer/arm/layernorm_arm_asimdhp.cpp @@ -0,0 +1,351 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "layernorm_arm.h" + +#if __ARM_NEON +#include +#endif // __ARM_NEON + +#include "arm_usability.h" + +namespace ncnn { + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +static void layernorm_fp16s(__fp16* ptr, const float* gamma_ptr, const float* beta_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + + float32x4_t _mean0 = vdupq_n_f32(0.f); + float32x4_t _mean1 = vdupq_n_f32(0.f); + float mean = 0.f; + { + const __fp16* ptr0 = ptr; + + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr0); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + _mean0 = vaddq_f32(_mean0, _p0); + _mean1 = vaddq_f32(_mean1, _p1); + ptr0 += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0)); + _mean0 = vaddq_f32(_mean0, _p); + ptr0 += 4; + } + for (; i < size; i++) + { + mean += (float)ptr0[0]; + ptr0++; + } + } + + if (elempack == 8) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + _mean0 = vdivq_f32(_mean0, _elemcount); + _mean1 = vdivq_f32(_mean1, _elemcount); + } + if (elempack == 4) + { + _mean0 = vaddq_f32(_mean0, _mean1); + + float32x4_t _elemcount = vdupq_n_f32(elemcount); + _mean0 = vdivq_f32(_mean0, _elemcount); + _mean1 = _mean0; + } + if (elempack == 1) + { + _mean0 = vaddq_f32(_mean0, _mean1); + mean += vaddvq_f32(_mean0); + + mean = mean / elemcount; + _mean0 = vdupq_n_f32(mean); + _mean1 = _mean0; + } + + float32x4_t _var0 = vdupq_n_f32(0.f); + float32x4_t _var1 = vdupq_n_f32(0.f); + float var = 0.f; + { + const __fp16* ptr0 = ptr; + + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr0); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + _p0 = vsubq_f32(_p0, _mean0); + _p1 = vsubq_f32(_p1, _mean1); + _var0 = vmlaq_f32(_var0, _p0, _p0); + _var1 = vmlaq_f32(_var1, _p1, _p1); + ptr0 += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0)); + _p = vsubq_f32(_p, _mean0); + _var0 = vmlaq_f32(_var0, _p, _p); + ptr0 += 4; + } + for (; i < size; i++) + { + float v = (float)ptr0[0] - mean; + var += v * v; + ptr0++; + } + } + + if (elempack == 8) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + _var0 = vdivq_f32(_var0, _elemcount); + _var1 = vdivq_f32(_var1, _elemcount); + _var0 = vaddq_f32(_var0, _eps); + _var1 = vaddq_f32(_var1, _eps); + float32x4_t _rsqrt_var0 = vrsqrteq_f32(_var0); + float32x4_t _rsqrt_var1 = vrsqrteq_f32(_var1); + _rsqrt_var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var0), _rsqrt_var0), _rsqrt_var0); + _rsqrt_var1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var1, _rsqrt_var1), _rsqrt_var1), _rsqrt_var1); + _var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var0), _rsqrt_var0), _rsqrt_var0); + _var1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var1, _rsqrt_var1), _rsqrt_var1), _rsqrt_var1); + _mean0 = vmulq_f32(_mean0, _var0); + _mean1 = vmulq_f32(_mean1, _var1); + _mean0 = vnegq_f32(_mean0); + _mean1 = vnegq_f32(_mean1); + } + if (elempack == 
4) + { + _var0 = vaddq_f32(_var0, _var1); + + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + _var0 = vdivq_f32(_var0, _elemcount); + _var0 = vaddq_f32(_var0, _eps); + float32x4_t _rsqrt_var = vrsqrteq_f32(_var0); + _rsqrt_var = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _var0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_var0, _rsqrt_var), _rsqrt_var), _rsqrt_var); + _var1 = _var0; + _mean0 = vmulq_f32(_mean0, _var0); + _mean0 = vnegq_f32(_mean0); + _mean1 = _mean0; + } + if (elempack == 1) + { + _var0 = vaddq_f32(_var0, _var1); + var += vaddvq_f32(_var0); + + var = 1.f / sqrtf(var / elemcount + eps); + mean = -mean * var; + _var0 = vdupq_n_f32(var); + _var1 = _var0; + _mean0 = vdupq_n_f32(mean); + _mean1 = _mean0; + } + + if (gamma_ptr && beta_ptr) + { + int i = 0; + if (elempack == 8) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _beta = vdupq_n_f32(beta_ptr[0]); + _p0 = vmlaq_f32(_mean0, _p0, _var0); + _p1 = vmlaq_f32(_mean1, _p1, _var1); + _p0 = vmlaq_f32(_beta, _p0, _gamma); + _p1 = vmlaq_f32(_beta, _p1, _gamma); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 1; + beta_ptr += 1; + } + } + if (elempack == 4) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma0 = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _gamma1 = vdupq_n_f32(gamma_ptr[1]); + float32x4_t _beta0 = vdupq_n_f32(beta_ptr[0]); + float32x4_t _beta1 = vdupq_n_f32(beta_ptr[1]); + _p0 = vmlaq_f32(_mean0, _p0, _var0); + _p1 = vmlaq_f32(_mean1, _p1, _var1); + _p0 = vmlaq_f32(_beta0, _p0, _gamma0); + _p1 = vmlaq_f32(_beta1, _p1, _gamma1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 2; + beta_ptr += 2; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _beta = vdupq_n_f32(beta_ptr[0]); + _p = vmlaq_f32(_mean0, _p, _var0); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + gamma_ptr += 1; + beta_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma0 = vld1q_f32(gamma_ptr); + float32x4_t _gamma1 = vld1q_f32(gamma_ptr + 4); + float32x4_t _beta0 = vld1q_f32(beta_ptr); + float32x4_t _beta1 = vld1q_f32(beta_ptr + 4); + _p0 = vmlaq_f32(_mean0, _p0, _var0); + _p1 = vmlaq_f32(_mean1, _p1, _var1); + _p0 = vmlaq_f32(_beta0, _p0, _gamma0); + _p1 = vmlaq_f32(_beta1, _p1, _gamma1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 8; + beta_ptr += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + float32x4_t _beta = vld1q_f32(beta_ptr); + _p = vmlaq_f32(_mean0, _p, _var0); + _p = vmlaq_f32(_beta, _p, _gamma); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + gamma_ptr += 4; + beta_ptr += 4; + } + } + for (; i < size; i++) + { + float v = (float)ptr[0]; + ptr[0] 
= (__fp16)((v * var + mean) * gamma_ptr[0] + beta_ptr[0]); + ptr++; + gamma_ptr++; + beta_ptr++; + } + } + else + { + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + _p0 = vmlaq_f32(_mean0, _p0, _var0); + _p1 = vmlaq_f32(_mean1, _p1, _var1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vmlaq_f32(_mean0, _p, _var0); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + } + for (; i < size; i++) + { + float v = (float)ptr[0]; + ptr[0] = (__fp16)(v * var + mean); + ptr++; + } + } +} + +int LayerNorm_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + __fp16* ptr = bottom_top_blob; + layernorm_fp16s(ptr, gamma_data, beta_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + layernorm_fp16s(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.channel(q).row<__fp16>(i); + layernorm_fp16s(ptr, gamma_data, beta_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + __fp16* ptr = bottom_top_blob.channel(q); + layernorm_fp16s(ptr, gamma_data, beta_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +} // namespace ncnn diff --git a/src/net.cpp b/src/net.cpp index 32b5b2abd601..904e14cb2f7d 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -707,7 +707,7 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio if (elembits == 16) { #if NCNN_ARM82 - if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic) + if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic && layer->support_fp16_storage) dst_elempack = 8; else if (elemcount % 4 == 0) dst_elempack = 4; diff --git a/tests/testutil.cpp b/tests/testutil.cpp index 837043cb754c..ffc12bccfa3c 100644 --- a/tests/testutil.cpp +++ b/tests/testutil.cpp @@ -406,7 +406,7 @@ static int convert_to_optimal_layout(const ncnn::Mat& a, ncnn::Mat& a4, const nc if (elembits == 16) { #if NCNN_ARM82 - if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic) + if (elemcount % 8 == 0 && ncnn::cpu_support_arm_asimdhp() && opt.use_fp16_arithmetic && op->support_fp16_storage) dst_elempack = 8; else if (elemcount % 4 == 0) dst_elempack = 4; From c1f9e959f546879b0cb8a00134fb06e7ad4faf2f Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 21 Oct 2024 16:34:07 +0800 Subject: [PATCH 10/15] pnnx torch 2.5 (#5748) --- .ci/pnnx.yml | 12 +++-- .../F_scaled_dot_product_attention.cpp | 45 +++++++++++++++++++ tools/pnnx/tests/ncnn/test_F_layer_norm.py | 2 +- 
tools/pnnx/tests/ncnn/test_nn_LayerNorm.py | 2 +- tools/pnnx/tests/onnx/test_F_relu.py | 2 +- tools/pnnx/tests/onnx/test_convnext_tiny.py | 2 +- tools/pnnx/tests/onnx/test_mobilenet_v2.py | 2 +- .../tests/onnx/test_mobilenet_v3_small.py | 2 +- tools/pnnx/tests/onnx/test_nn_ReLU.py | 2 +- tools/pnnx/tests/onnx/test_resnet18.py | 2 +- .../tests/onnx/test_shufflenet_v2_x1_0.py | 2 +- tools/pnnx/tests/onnx/test_squeezenet1_1.py | 2 +- tools/pnnx/tests/onnx/test_swin_t.py | 2 +- tools/pnnx/tests/onnx/test_vit_b_32.py | 2 +- 14 files changed, 65 insertions(+), 16 deletions(-) diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index d49da39a0afc..207d78c4e2d2 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -19,10 +19,10 @@ concurrency: variables: protobuf_version: 21.12 - libtorch_version: 2.4.0 - libtorchvision_version: 0.19.0 - onnxruntime_version: 1.18.1 - cache_date: 20240804 + libtorch_version: 2.5.0 + libtorchvision_version: 0.20.0 + onnxruntime_version: 1.19.2 + cache_date: 20241018 jobs: ubuntu: @@ -62,6 +62,9 @@ jobs: - torch-version: 2.4.0 torchvision-version: 0.19.0 + - torch-version: 2.5.0 + torchvision-version: 0.20.0 + runs-on: pool-name: docker container: @@ -157,6 +160,7 @@ jobs: cd onnxruntime-${{variables.onnxruntime_version}} patch -p1 -i ${{ci.workspace}}/pnnx-patches/onnxruntime-${{variables.onnxruntime_version}}-less-mlas-features.patch patch -p1 -i ${{ci.workspace}}/pnnx-patches/onnxruntime-${{variables.onnxruntime_version}}-monolithic-static-library.patch + patch -p1 -i ${{ci.workspace}}/pnnx-patches/onnxruntime-${{variables.onnxruntime_version}}-fix-gcc-avxvnni-check.patch mkdir -p build && cd build cmake -DCMAKE_INSTALL_PREFIX=${{ci.workspace}}/pnnx-deps-onnx-install -DCMAKE_BUILD_TYPE=MinSizeRel -Donnxruntime_USE_FULL_PROTOBUF=ON -Donnxruntime_BUILD_SHARED_LIB=ON -Donnxruntime_BUILD_UNIT_TESTS=OFF -Donnxruntime_ENABLE_CPUINFO=OFF -Donnxruntime_DISABLE_CONTRIB_OPS=ON -Donnxruntime_DISABLE_ML_OPS=ON -Donnxruntime_DISABLE_SPARSE_TENSORS=ON --compile-no-warning-as-error ../cmake cmake --build . 
-j $(nproc) diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp index 9fba1e770cc5..bb11aad3d0ca 100644 --- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp @@ -80,6 +80,51 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) +class F_scaled_dot_product_attention_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input input_3 0 1 attn_mask +prim::Constant op_0 0 1 dropout_p value=%dropout_p +prim::Constant op_1 0 1 is_causal value=%is_causal +prim::Constant op_2 0 1 scale value=%scale +prim::Constant op_3 0 1 enable_gqa value=%enable_gqa +aten::scaled_dot_product_attention op_4 8 1 query key value attn_mask dropout_p is_causal scale enable_gqa out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.scaled_dot_product_attention"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(op, captured_params, captured_attrs); + + if (captured_params.at("scale").type == 0) + { + // drop scale=None for compatibility with old torch + op->params.erase("scale"); + } + + if (captured_params.at("enable_gqa").type == 1 && captured_params.at("enable_gqa").b == false) + { + // drop enable_gqa=False for compatibility with old torch + op->params.erase("enable_gqa"); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_2, 10) + static bool NearlyEqual(float a, float b, float epsilon) { if (a == b) diff --git a/tools/pnnx/tests/ncnn/test_F_layer_norm.py b/tools/pnnx/tests/ncnn/test_F_layer_norm.py index 9d590aa76dda..7815e4e687b5 100644 --- a/tools/pnnx/tests/ncnn/test_F_layer_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_layer_norm.py @@ -55,7 +55,7 @@ def test(): b = test_F_layer_norm_ncnn.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-4, 1e-4): + if not torch.allclose(a0, b0, 1e-3, 1e-3): return False return True diff --git a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py index d409bdfba3a1..672142208ef7 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py @@ -54,7 +54,7 @@ def test(): b = test_nn_LayerNorm_ncnn.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-4, 1e-4): + if not torch.allclose(a0, b0, 1e-3, 1e-3): return False return True diff --git a/tools/pnnx/tests/onnx/test_F_relu.py b/tools/pnnx/tests/onnx/test_F_relu.py index 0bb08d6920bb..f980cc081a38 100644 --- a/tools/pnnx/tests/onnx/test_F_relu.py +++ b/tools/pnnx/tests/onnx/test_F_relu.py @@ -59,7 +59,7 @@ def test(): if not torch.allclose(a0, b0, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.3'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_convnext_tiny.py b/tools/pnnx/tests/onnx/test_convnext_tiny.py index 530ee8eb5f8f..e28494dbe103 100644 --- a/tools/pnnx/tests/onnx/test_convnext_tiny.py +++ b/tools/pnnx/tests/onnx/test_convnext_tiny.py @@ -43,7 +43,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if 
version.parse(torch.__version__) < version.parse('2.4'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_mobilenet_v2.py b/tools/pnnx/tests/onnx/test_mobilenet_v2.py index b3e0648002bf..add698ad1f77 100644 --- a/tools/pnnx/tests/onnx/test_mobilenet_v2.py +++ b/tools/pnnx/tests/onnx/test_mobilenet_v2.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.4'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py b/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py index 38a638668aee..32827d5ffa25 100644 --- a/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py +++ b/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py @@ -42,7 +42,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.4'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_nn_ReLU.py b/tools/pnnx/tests/onnx/test_nn_ReLU.py index 8230e3f4827a..a84145229426 100644 --- a/tools/pnnx/tests/onnx/test_nn_ReLU.py +++ b/tools/pnnx/tests/onnx/test_nn_ReLU.py @@ -61,7 +61,7 @@ def test(): if not torch.allclose(a0, b0, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.5'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_resnet18.py b/tools/pnnx/tests/onnx/test_resnet18.py index 57de5d1bdb65..583f88ce198f 100644 --- a/tools/pnnx/tests/onnx/test_resnet18.py +++ b/tools/pnnx/tests/onnx/test_resnet18.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.4'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py b/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py index ad566a1c1c0d..4b498f67b613 100644 --- a/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py +++ b/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.4'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_squeezenet1_1.py b/tools/pnnx/tests/onnx/test_squeezenet1_1.py index 28c7df8fb81e..4e9683da48d8 100644 --- a/tools/pnnx/tests/onnx/test_squeezenet1_1.py +++ b/tools/pnnx/tests/onnx/test_squeezenet1_1.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.5'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_swin_t.py b/tools/pnnx/tests/onnx/test_swin_t.py index 6361d20c9116..e78855d41540 100644 --- a/tools/pnnx/tests/onnx/test_swin_t.py +++ b/tools/pnnx/tests/onnx/test_swin_t.py @@ -43,7 +43,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.5'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx diff --git 
a/tools/pnnx/tests/onnx/test_vit_b_32.py b/tools/pnnx/tests/onnx/test_vit_b_32.py index 3c92a119406a..678c0e43230c 100644 --- a/tools/pnnx/tests/onnx/test_vit_b_32.py +++ b/tools/pnnx/tests/onnx/test_vit_b_32.py @@ -46,7 +46,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.5'): + if version.parse(torch.__version__) < version.parse('2.6'): return True # export dynamo onnx From e7602a206bd2511791e72dbd89f731f305abe9b0 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 21 Oct 2024 19:09:43 +0800 Subject: [PATCH 11/15] fix gemm arm int8 scales descales offset (#5750) --- src/layer/arm/gemm_arm.cpp | 1 - src/layer/arm/gemm_int8.h | 60 ++++++++++++++--------------- src/layer/arm/gemm_int8_bf16s.h | 60 ++++++++++++++--------------- src/layer/arm/gemm_int8_fp16s.h | 60 ++++++++++++++--------------- tests/test_multiheadattention_1.cpp | 6 +-- 5 files changed, 93 insertions(+), 94 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 7607d8f523e5..09f25869f43d 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -4404,7 +4404,6 @@ int Gemm_arm::forward(const std::vector& bottom_blobs, std::vector& to if (int8_scale_term) { return forward_int8(bottom_blobs, top_blobs, opt); - // return Gemm::forward_int8(bottom_blobs, top_blobs, opt); } #endif diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 68688c863102..652f300b4fd5 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -1724,8 +1724,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s const float v127_B_scale = 127.f * B_scale; - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON if (elempack == 4) @@ -1897,8 +1897,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); if (elempack == 4) { @@ -2314,7 +2314,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); if (elempack == 4) { @@ -2592,8 +2592,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; // if (elempack == 1) { @@ -2680,7 +2680,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k; - const float scale = scales[ii]; + const float scale = scales[i + ii]; // if (elempack == 1) { @@ -2750,8 +2750,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, #endif #endif - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON if (elempack == 4) @@ -3055,8 +3055,8 
@@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); if (elempack == 4) { @@ -3396,7 +3396,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); if (elempack == 4) { @@ -3622,8 +3622,8 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; #if __ARM_NEON float32x4_t _scale0 = vdupq_n_f32(scale0); @@ -3805,7 +3805,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - const float scale = scales[ii]; + const float scale = scales[i + ii]; #if __ARM_NEON float32x4_t _scale = vdupq_n_f32(scale); @@ -5646,8 +5646,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -6593,7 +6593,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -7181,10 +7181,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x2_t _descale = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -7467,7 +7467,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif @@ -7726,8 +7726,8 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = 
vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -8673,7 +8673,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -9237,10 +9237,10 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale01 = vld1_f32((const float*)descales + ii); + float32x2_t _descale01 = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -9556,7 +9556,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 350f20ab4c0f..a1ad87d51229 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -38,8 +38,8 @@ static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_s const float v127_B_scale = 127.f * B_scale; - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON if (elempack == 4) @@ -217,8 +217,8 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); if (elempack == 4) { @@ -665,7 +665,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); if (elempack == 4) { @@ -958,8 +958,8 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; // if (elempack == 1) { @@ -1048,7 +1048,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const float scale = scales[ii]; + const float scale = scales[i + ii]; // if (elempack == 1) { @@ -1121,8 +1121,8 @@ static void transpose_compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, #endif #endif - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON if (elempack == 4) @@ -1362,8 +1362,8 @@ static void 
transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); if (elempack == 4) { @@ -1731,7 +1731,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); if (elempack == 4) { @@ -1963,8 +1963,8 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; #if __ARM_NEON float32x4_t _scale0 = vdupq_n_f32(scale0); @@ -2187,7 +2187,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - const float scale = scales[ii]; + const float scale = scales[i + ii]; #if __ARM_NEON float32x4_t _scale = vdupq_n_f32(scale); @@ -4169,8 +4169,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -5189,7 +5189,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -5794,10 +5794,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x2_t _descale = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -6097,7 +6097,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif @@ -6359,8 +6359,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = 
vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -7318,7 +7318,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -7902,10 +7902,10 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale01 = vld1_f32((const float*)descales + ii); + float32x2_t _descale01 = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -8250,7 +8250,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index e096a6caf6f6..0ea6c389c8b2 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -52,8 +52,8 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s const float v127_B_scale = 127.f * B_scale; - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -390,8 +390,8 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); #if __aarch64__ if (elempack == 8) @@ -1007,7 +1007,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); if (elempack == 4) { @@ -1300,8 +1300,8 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; // if (elempack == 1) { @@ -1390,7 +1390,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; - const float scale = scales[ii]; + const float scale = scales[i + ii]; // if (elempack == 1) { @@ -1471,8 +1471,8 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, 
#endif #endif - float* ps = scales; - float* pods = out_descales; + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; #if __ARM_NEON #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -2035,8 +2035,8 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); - float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4); #if __aarch64__ if (elempack == 8) @@ -2510,7 +2510,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - float32x4_t _scale = vld1q_f32((const float*)scales + ii); + float32x4_t _scale = vld1q_f32((const float*)scales + i + ii); #if __aarch64__ if (elempack == 8) @@ -2803,8 +2803,8 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - const float scale0 = scales[ii]; - const float scale1 = scales[ii + 1]; + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; #if __ARM_NEON float32x4_t _scale0 = vdupq_n_f32(scale0); @@ -3068,7 +3068,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int { const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; - const float scale = scales[ii]; + const float scale = scales[i + ii]; #if __ARM_NEON float32x4_t _scale = vdupq_n_f32(scale); @@ -5605,8 +5605,8 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& { unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -6813,7 +6813,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& { unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -7418,10 +7418,10 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x2_t _descale = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -7721,7 +7721,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& // out_elempack == 1 unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif @@ -7983,8 +7983,8 @@ static 
void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); - float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4); float32x4_t _c0; float32x4_t _c1; @@ -9088,7 +9088,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - float32x4_t _descale = vld1q_f32((const float*)descales + ii); + float32x4_t _descale = vld1q_f32((const float*)descales + i + ii); float32x4_t _c0; if (pC) @@ -9683,10 +9683,10 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; #if __ARM_NEON - float32x2_t _descale01 = vld1_f32((const float*)descales + ii); + float32x2_t _descale01 = vld1_f32((const float*)descales + i + ii); #endif float c0; @@ -10038,7 +10038,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; - const float descale = descales[ii]; + const float descale = descales[i + ii]; #if __ARM_NEON float32x4_t _descale = vdupq_n_f32(descale); #endif diff --git a/tests/test_multiheadattention_1.cpp b/tests/test_multiheadattention_1.cpp index c29930a0be8c..7039b19cc3cb 100644 --- a/tests/test_multiheadattention_1.cpp +++ b/tests/test_multiheadattention_1.cpp @@ -55,7 +55,7 @@ static int test_multiheadattention_int8(const ncnn::Mat& q, const ncnn::Mat& k, as.push_back(RandomMat(k.h, q.h)); } - float epsilon = 0.15; + float epsilon = 0.1; int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) @@ -98,7 +98,7 @@ static int test_multiheadattention_int8_samekv(const ncnn::Mat& q, const ncnn::M as[0] = q; as[1] = kv; - float epsilon = 0.15; + float epsilon = 0.1; int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) @@ -139,7 +139,7 @@ static int test_multiheadattention_int8_sameqkv(const ncnn::Mat& a, int embed_di std::vector as(1); as[0] = a; - float epsilon = 0.15; + float epsilon = 0.1; int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon); if (ret != 0) From 6077adc6bc08ae89d3f41c817cec5e9cd6882117 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 22 Oct 2024 18:31:22 +0800 Subject: [PATCH 12/15] pnnx do not fold tensor with dynamic shape, use fp32 module by default (#5755) --- tools/pnnx/src/ir.cpp | 3 +++ tools/pnnx/src/pass_level0/shape_inference.cpp | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 8b2b6dfd2d7f..a0eb8d692bff 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2390,6 +2390,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { fprintf(pyfp, "def export_torchscript():\n"); fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.float()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); @@ -2455,6 
+2456,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { fprintf(pyfp, "def export_onnx():\n"); fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.float()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); @@ -2576,6 +2578,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { fprintf(pyfp, "def test_inference():\n"); fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.float()\n"); fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); fprintf(pyfp, " torch.manual_seed(0)\n"); diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index a273dd79df88..5865390bdfa0 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -418,12 +418,14 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr sizes1 = type1->symbolic_sizes().sizes().value(); std::vector sizes2 = type2->symbolic_sizes().sizes().value(); + bool is_shape_static = true; for (size_t i = 0; i < sizes1.size(); i++) { if (sizes1[i] == sizes2[i]) continue; sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); + is_shape_static = false; } auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); @@ -431,7 +433,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(finaltype); // check if value that does not depend on inputs - if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) + if (is_shape_static && value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) { // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); foldable_constants.insert(v->debugName()); From c32442aa09a9ce7abdc9c17dd06aef8b7edc74cc Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 29 Oct 2024 15:07:05 +0800 Subject: [PATCH 13/15] disable x86 auto recip optimization for potential precision loss (#5762) --- CMakeLists.txt | 4 ++++ src/CMakeLists.txt | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 875a8d06598f..440838f5a657 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -517,6 +517,8 @@ else() unset(CMAKE_REQUIRED_FLAGS) elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") + check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE) + check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_AVX) set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mfma -mf16c") @@ -543,6 +545,8 @@ else() unset(CMAKE_REQUIRED_FLAGS) else() + check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE) + check_cxx_compiler_flag("-mavx" NCNN_COMPILER_SUPPORT_X86_AVX) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 803c34a780d4..449e5864b9ee 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -403,6 +403,11 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") target_compile_options(ncnn PRIVATE -msimd128) endif() endif() + + if(NCNN_COMPILER_SUPPORT_X86_RECIP_NONE) + # recip optimization causes precision loss + target_compile_options(ncnn PRIVATE -mrecip=none) + endif() endif() if(NOT NCNN_RUNTIME_CPU AND NCNN_AVX512) From e71fdf8e51033048da4d59e1bb8dabdc334f09e4 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 7 Nov 2024 10:27:11 +0800 Subject: [PATCH 14/15] pnnx write implicit int 
conversion in python script (#5767) --- .ci/pnnx.yml | 4 ++-- tools/pnnx/src/ir.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 207d78c4e2d2..a08379ff8dbc 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -200,8 +200,8 @@ jobs: export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 export MKL_ENABLE_INSTRUCTIONS=SSE4_2 - cd tools/pnnx - cd build && ctest --output-on-failure -j 16 + cd tools/pnnx/build + ctest --output-on-failure -j 16 - name: python-pnnx run: | diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index a0eb8d692bff..9e616699a9e5 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1051,7 +1051,7 @@ static std::string expand_expression(const Operator* op) || t == "torch.long") { std::string unaryop = t; - if (t == "int") unaryop = "int"; + if (t == "int") unaryop = ""; // but the explicit int() causes troubles in tracing if (t == "abs") unaryop = "torch.abs"; if (t == "acos") unaryop = "torch.acos"; if (t == "acosh") unaryop = "torch.acosh"; From 9cefe9a6243e2feb420b64ad3d0908d9abca78d4 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 14 Nov 2024 15:00:07 +0800 Subject: [PATCH 15/15] avx vnni int8, avx vnni int16, avx ne convert infrastructure (#5749) --- CMakeLists.txt | 48 +++++++++++++++++ cmake/ncnn_add_layer.cmake | 27 ++++++++++ src/CMakeLists.txt | 27 ++++++++++ src/cpu.cpp | 102 +++++++++++++++++++++++++++++++++++++ src/cpu.h | 6 +++ src/platform.h.in | 3 ++ 6 files changed, 213 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 440838f5a657..bf0e9f20fb8e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -506,6 +506,15 @@ else() set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include \nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -534,6 +543,15 @@ else() set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni") check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert") + check_cxx_source_compiles("#include \nint main() 
{ __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -560,6 +578,15 @@ else() set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni") check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16") + check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert") + check_cxx_source_compiles("#include \nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -603,9 +630,30 @@ else() if(NCNN_AVX2) option(NCNN_AVXVNNI "optimize x86 platform with avx vnni extension" ON) endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + if(NCNN_AVXVNNI) + option(NCNN_AVXVNNIINT8 "optimize x86 platform with avx vnni int8 extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx vnni int8 extension. NCNN_AVXVNNIINT8 will be OFF.") + endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + if(NCNN_AVXVNNI) + option(NCNN_AVXVNNIINT16 "optimize x86 platform with avx vnni int16 extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx vnni int16 extension. NCNN_AVXVNNIINT16 will be OFF.") + endif() else() message(WARNING "The compiler does not support avx vnni extension. NCNN_AVXVNNI will be OFF.") endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + if(NCNN_AVX2) + option(NCNN_AVXNECONVERT "optimize x86 platform with avx ne convert extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx ne convert extension. 
NCNN_AVXNECONVERT will be OFF.") + endif() if(NCNN_COMPILER_SUPPORT_X86_AVX512) if(NCNN_AVX2) option(NCNN_AVX512 "optimize x86 platform with avx512 extension" ON) diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 7f334fb0b68d..d9f898f62c95 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -156,6 +156,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() @@ -187,6 +196,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 -mfma -mf16c -mavxneconvert /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() @@ -218,6 +236,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint8") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint16") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "-mavx2 -mfma -mf16c -mavxneconvert") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c") endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 449e5864b9ee..4aa952e3d0f0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -446,6 +446,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE /D__AVXVNNIINT8__) + endif() + if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE /D__AVXVNNIINT16__) + endif() + if(NCNN_AVXNECONVERT) + 
target_compile_options(ncnn PRIVATE /D__AVXNECONVERT__) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE /D__AVXVNNI__) elseif(NCNN_XOP) @@ -460,6 +469,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE /arch:AVX -mfma /D__SSSE3__ /D__SSE4_1__ /D__FMA__) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE -mavxvnniint8 /D__AVXVNNIINT8__) + endif() + if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE -mavxvnniint16 /D__AVXVNNIINT16__) + endif() + if(NCNN_AVXNECONVERT) + target_compile_options(ncnn PRIVATE -mavxneconvert /D__AVXNECONVERT__) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE -mavxvnni /D__AVXVNNI__) elseif(NCNN_XOP) @@ -474,6 +492,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE -mavx -mfma) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE -mavxvnniint8) + endif() + if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE -mavxvnniint16) + endif() + if(NCNN_AVXNECONVERT) + target_compile_options(ncnn PRIVATE -mavxneconvert) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE -mavxvnni) elseif(NCNN_XOP) diff --git a/src/cpu.cpp b/src/cpu.cpp index 9ab0ebb31e99..c9307619ce91 100644 --- a/src/cpu.cpp +++ b/src/cpu.cpp @@ -183,6 +183,9 @@ static int g_cpu_support_x86_xop; static int g_cpu_support_x86_f16c; static int g_cpu_support_x86_avx2; static int g_cpu_support_x86_avx_vnni; +static int g_cpu_support_x86_avx_vnni_int8; +static int g_cpu_support_x86_avx_vnni_int16; +static int g_cpu_support_x86_avx_ne_convert; static int g_cpu_support_x86_avx512; static int g_cpu_support_x86_avx512_vnni; static int g_cpu_support_x86_avx512_bf16; @@ -617,6 +620,72 @@ static int get_cpu_support_x86_avx_vnni() return cpu_info[0] & (1u << 4); } +static int get_cpu_support_x86_avx_vnni_int8() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 4); +} + +static int get_cpu_support_x86_avx_vnni_int16() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 10); +} + +static int get_cpu_support_x86_avx_ne_convert() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 5); +} + static int get_cpu_support_x86_avx512() { #if __APPLE__ @@ -1967,6 +2036,9 @@ static void initialize_global_cpu_info() g_cpu_support_x86_f16c = get_cpu_support_x86_f16c(); g_cpu_support_x86_avx2 = get_cpu_support_x86_avx2(); 
g_cpu_support_x86_avx_vnni = get_cpu_support_x86_avx_vnni(); + g_cpu_support_x86_avx_vnni_int8 = get_cpu_support_x86_avx_vnni_int8(); + g_cpu_support_x86_avx_vnni_int16 = get_cpu_support_x86_avx_vnni_int16(); + g_cpu_support_x86_avx_ne_convert = get_cpu_support_x86_avx_ne_convert(); g_cpu_support_x86_avx512 = get_cpu_support_x86_avx512(); g_cpu_support_x86_avx512_vnni = get_cpu_support_x86_avx512_vnni(); g_cpu_support_x86_avx512_bf16 = get_cpu_support_x86_avx512_bf16(); @@ -2489,6 +2561,36 @@ int cpu_support_x86_avx_vnni() #endif } +int cpu_support_x86_avx_vnni_int8() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_vnni_int8; +#else + return 0; +#endif +} + +int cpu_support_x86_avx_vnni_int16() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_vnni_int16; +#else + return 0; +#endif +} + +int cpu_support_x86_avx_ne_convert() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_ne_convert; +#else + return 0; +#endif +} + int cpu_support_x86_avx512() { try_initialize_global_cpu_info(); diff --git a/src/cpu.h b/src/cpu.h index 2ae6b8c3ffe9..f0e4728633fe 100644 --- a/src/cpu.h +++ b/src/cpu.h @@ -93,6 +93,12 @@ NCNN_EXPORT int cpu_support_x86_f16c(); NCNN_EXPORT int cpu_support_x86_avx2(); // avx_vnni = x86 avx vnni NCNN_EXPORT int cpu_support_x86_avx_vnni(); +// avx_vnni_int8 = x86 avx vnni int8 +NCNN_EXPORT int cpu_support_x86_avx_vnni_int8(); +// avx_vnni_int16 = x86 avx vnni int16 +NCNN_EXPORT int cpu_support_x86_avx_vnni_int16(); +// avx_ne_convert = x86 avx ne convert +NCNN_EXPORT int cpu_support_x86_avx_ne_convert(); // avx512 = x86 avx512f + avx512cd + avx512bw + avx512dq + avx512vl NCNN_EXPORT int cpu_support_x86_avx512(); // avx512_vnni = x86 avx512 vnni diff --git a/src/platform.h.in b/src/platform.h.in index 50a9454b7da0..a0b372d8296b 100644 --- a/src/platform.h.in +++ b/src/platform.h.in @@ -40,6 +40,9 @@ #cmakedefine01 NCNN_F16C #cmakedefine01 NCNN_AVX2 #cmakedefine01 NCNN_AVXVNNI +#cmakedefine01 NCNN_AVXVNNIINT8 +#cmakedefine01 NCNN_AVXVNNIINT16 +#cmakedefine01 NCNN_AVXNECONVERT #cmakedefine01 NCNN_AVX512 #cmakedefine01 NCNN_AVX512VNNI #cmakedefine01 NCNN_AVX512BF16
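
The fp16 and bf16 LayerNorm kernels in this series avoid a per-element subtraction by folding the mean into the scale: after the statistics pass they keep a = 1/sqrt(var + eps) and b = -mean * a, so normalization is a single multiply-add per lane, which is exactly what the scalar tail loops compute as (v * var + mean). A minimal standalone scalar sketch of that formulation (illustration only, not ncnn source):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const float eps = 1e-5f;
        float x[4] = {1.f, 2.f, 3.f, 4.f}; // example values

        float mean = 0.f;
        for (float v : x) mean += v;
        mean /= 4;

        float var = 0.f;
        for (float v : x) var += (v - mean) * (v - mean);
        var /= 4;

        const float a = 1.f / sqrtf(var + eps); // the "var" register after refinement
        const float b = -mean * a;              // the "mean" register after folding

        for (float& v : x)
            v = v * a + b; // matches the (v * var + mean) tail loops above

        for (float v : x) printf("%f\n", v);
        return 0;
    }
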
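
The gemm_int8 scale fixes all make the same change: the per-row quantization scale and descale buffers span every row of A, so a tile that starts at row i must read entry i + ii rather than the tile-local ii, otherwise every tile re-reads the scales of the first tile. A toy illustration of the indexing (hypothetical buffers and tile sizes, not the ncnn kernels):

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 8;      // total rows of A
        const int TILE_M = 4; // rows handled per tile

        std::vector<float> scales(M);
        for (int r = 0; r < M; r++)
            scales[r] = 1.f + r; // distinct per-row quantization scale

        for (int i = 0; i < M; i += TILE_M)     // tile loop over rows
        {
            for (int ii = 0; ii < TILE_M; ii++) // rows inside the current tile
            {
                // scales[ii] would reuse rows 0..3 for every tile;
                // scales[i + ii] picks the scale of the row actually packed.
                printf("tile %d row %d scale %.1f\n", i / TILE_M, i + ii, scales[i + ii]);
            }
        }
        return 0;
    }
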
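
The shape_inference change only treats a value as a foldable constant when its shape agrees across both traced input sets; any mismatching dimension is marked dynamic and folding is disabled for that value. A simplified sketch of that merge rule, with plain ints standing in for the symbolic shape type and -1 marking a dynamic dimension (example shapes are made up):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> sizes1 = {1, 77, 768}; // shape seen with input set 1
        std::vector<int> sizes2 = {1, 64, 768}; // shape seen with input set 2

        bool is_shape_static = true;
        for (size_t i = 0; i < sizes1.size(); i++)
        {
            if (sizes1[i] == sizes2[i])
                continue;
            sizes1[i] = -1;          // dimension varies with the input shape
            is_shape_static = false; // so the value must not be constant-folded
        }

        printf("is_shape_static = %d, merged shape = %d %d %d\n",
               is_shape_static, sizes1[0], sizes1[1], sizes1[2]);
        return 0;
    }
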
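
The new x86 capability queries are exported next to the existing ones in cpu.h, so runtime dispatch can gate AVX-VNNI-INT8 / AVX-VNNI-INT16 / AVX-NE-CONVERT kernels the same way the AVX2 and AVX-VNNI paths already are. A possible application-side check, assuming a standard ncnn install where the public header is reachable as ncnn/cpu.h (adjust the include path to your layout):

    #include <cstdio>

    #include <ncnn/cpu.h> // assumed install location of ncnn's cpu.h

    int main()
    {
        printf("avx_vnni       = %d\n", ncnn::cpu_support_x86_avx_vnni());
        printf("avx_vnni_int8  = %d\n", ncnn::cpu_support_x86_avx_vnni_int8());
        printf("avx_vnni_int16 = %d\n", ncnn::cpu_support_x86_avx_vnni_int16());
        printf("avx_ne_convert = %d\n", ncnn::cpu_support_x86_avx_ne_convert());

        if (ncnn::cpu_support_x86_avx_vnni_int8())
        {
            // safe to dispatch to an s8s8 dot-product kernel here
        }
        return 0;
    }

On non-x86 targets these queries simply return 0, so the same guard compiles everywhere.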